# PyTorch Pretrained BERT: The Big and Extending Repository of (pre-trained) Transformers

[![CircleCI](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT.svg?style=svg)](https://circleci.com/gh/huggingface/pytorch-pretrained-BERT)

This repository contains op-for-op PyTorch reimplementations, pre-trained models and fine-tuning examples for:

- [Google's BERT model](https://github.com/google-research/bert),
- [OpenAI's GPT model](https://github.com/openai/finetune-transformer-lm), and
- [Google/CMU's Transformer-XL model](https://github.com/kimiyoung/transformer-xl).

These implementations have been tested on several datasets (see the examples) and should match the performance of the associated TensorFlow implementations (e.g. ~91 F1 on SQuAD for BERT, ~88 F1 on RocStories for OpenAI GPT and ~18.3 perplexity on WikiText 103 for the Transformer-XL). You can find more details in the [Examples](#examples) section below.

Here is some information about these models:

**BERT** was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
This PyTorch implementation of BERT is provided with [Google's pre-trained models](https://github.com/google-research/bert), examples, notebooks and a command-line interface to load any pre-trained TensorFlow checkpoint for BERT.

**OpenAI GPT** was released together with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
This PyTorch implementation of OpenAI GPT is an adaptation of the [PyTorch implementation by HuggingFace](https://github.com/huggingface/pytorch-openai-transformer-lm) and is provided with [OpenAI's pre-trained model](https://github.com/openai/finetune-transformer-lm) and a command-line interface that was used to convert the pre-trained NumPy checkpoint to PyTorch.

**Google/CMU's Transformer-XL** was released together with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
This PyTorch implementation of Transformer-XL is an adaptation of the original [PyTorch implementation](https://github.com/kimiyoung/transformer-xl) which has been slightly modified to match the performance of the TensorFlow implementation and to allow re-use of the pre-trained weights. A command-line interface is provided to convert TensorFlow checkpoints to PyTorch models.

## Content

| Section | Description |
|-|-|
| [Installation](#installation) | How to install the package |
| [Overview](#overview) | Overview of the package |
| [Usage](#usage) | Quickstart examples |
| [Doc](#doc) |  Detailed documentation |
| [Examples](#examples) | Detailed examples of how to fine-tune and evaluate the models |
| [Notebooks](#notebooks) | Introduction on the provided Jupyter Notebooks |
| [TPU](#tpu) | Notes on TPU support and pretraining scripts |
| [Command-line interface](#Command-line-interface) | Convert a TensorFlow checkpoint into a PyTorch dump |

## Installation

This repo was tested on Python 2.7 and 3.5+ (examples are tested only on Python 3.5+) and PyTorch 0.4.1/1.0.0.

### With pip

PyTorch Pretrained BERT can be installed with pip as follows:
```bash
pip install pytorch-pretrained-bert
```

If you want to use the `OpenAI GPT` tokenizer, you will need to install `ftfy` (if you are using Python 2, version 4.4.3 is the last version that works) and `SpaCy`:
```bash
pip install spacy ftfy==4.4.3
python -m spacy download en
```

### From source

Clone the repository and run:
```bash
pip install [--editable] .
```

Here too, if you want to use the `OpenAI GPT` tokenizer, you will need to install `ftfy` (limited to version 4.4.3 if you are using Python 2) and `SpaCy`:
```bash
pip install spacy ftfy==4.4.3
python -m spacy download en
```


A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`).

You can run the tests with the command:
```bash
python -m pytest -sv tests/
```

## Overview

This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:

- Eight **Bert** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L556) - raw BERT Transformer model (**fully pre-trained**),
  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L710) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L771) - BERT Transformer with the pre-trained next sentence prediction classifier on top  (**fully pre-trained**),
  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L639) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L833) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L899) - BERT Transformer with a multiple choice head on top (used for tasks like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L969) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1034) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).

- Three **OpenAI GPT** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
  - [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
  - [`OpenAIGPTLMHeadModel`](./pytorch_pretrained_bert/modeling_openai.py#L691) - OpenAI GPT Transformer with the tied language modeling head on top (**fully pre-trained**),
  - [`OpenAIGPTDoubleHeadsModel`](./pytorch_pretrained_bert/modeling_openai.py#L752) - OpenAI GPT Transformer with the tied language modeling head and a multiple choice classification head on top (OpenAI GPT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**).

- Two **Transformer-XL** PyTorch models (`torch.nn.Module`) with pre-trained weights (in the [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) file):
  - [`TransfoXLModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L974) - Transformer-XL model which outputs the last hidden state and memory cells (**fully pre-trained**),
  - [`TransfoXLLMHeadModel`](./pytorch_pretrained_bert/modeling_transfo_xl.py#L1236) - Transformer-XL with the tied adaptive softmax head on top for language modeling which outputs the logits/loss and memory cells (**fully pre-trained**).

- Tokenizers for **BERT** (using word-piece) (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
  - `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
  - `WordpieceTokenizer` - WordPiece tokenization,
  - `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.

- Tokenizer for **OpenAI GPT** (using Byte-Pair-Encoding) (in the [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) file):
  - `OpenAIGPTTokenizer` - perform Byte-Pair-Encoding (BPE) tokenization.

- Tokenizer for **Transformer-XL** (word tokens ordered by frequency for adaptive softmax) (in the [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) file):
  - `TransfoXLTokenizer` - perform word tokenization and can order words by frequency in a corpus for use in an adaptive softmax.

- Optimizer for **BERT** (in the [`optimization.py`](./pytorch_pretrained_bert/optimization.py) file):
  - `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.

- Optimizer for **OpenAI GPT** (in the [`optimization_openai.py`](./pytorch_pretrained_bert/optimization_openai.py) file):
  - `OpenAIGPTAdam` - OpenAI GPT version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.

- Configuration classes for BERT, OpenAI GPT and Transformer-XL (in the respective [`modeling.py`](./pytorch_pretrained_bert/modeling.py), [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py), [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py) files):
  - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
  - `OpenAIGPTConfig` - Configuration class to store the configuration of an `OpenAIGPTModel` with utilities to read and write from JSON configuration files.
  - `TransfoXLConfig` - Configuration class to store the configuration of a `TransfoXLModel` with utilities to read and write from JSON configuration files.
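The WordPiece step performed by `WordpieceTokenizer` can be illustrated with a greedy longest-match-first split: take the longest prefix of the word that is in the vocabulary, then repeat on the remainder using `##`-prefixed continuation pieces, falling back to an unknown token if nothing matches. The sketch below assumes a small toy vocabulary and a hypothetical `wordpiece_tokenize` function; it illustrates the idea and is not the code in `tokenization.py`.

```python
from typing import List, Set


def wordpiece_tokenize(word: str, vocab: Set[str], unk_token: str = "[UNK]") -> List[str]:
    """Hypothetical helper: greedy longest-match-first split of one basic-tokenized word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate substring from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matches at this position: the whole word maps to the unknown token.
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"puppet", "##eer", "jim", "henson", "un", "##aff", "##able"}
print(wordpiece_tokenize("puppeteer", toy_vocab))  # ['puppet', '##eer']
print(wordpiece_tokenize("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("xyz", toy_vocab))        # ['[UNK]']
```

With this toy vocabulary, `puppeteer` splits into `puppet` and `##eer`, the same pieces that appear in the BERT usage example below.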

The repository further comprises:

- Five examples on how to use **BERT** (in the [`examples` folder](./examples)):
  - [`extract_features.py`](./examples/extract_features.py) - Show how to extract hidden states from an instance of `BertModel`,
  - [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
  - [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 and SQuAD v2.0 tasks.
  - [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
  - [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPreTraining` on a target text corpus.

- One example on how to use **OpenAI GPT** (in the [`examples` folder](./examples)):
  - [`run_openai_gpt.py`](./examples/run_openai_gpt.py) - Show how to fine-tune an instance of `OpenAIGPTDoubleHeadsModel` on the RocStories task.

- One example on how to use **Transformer-XL** (in the [`examples` folder](./examples)):
  - [`run_transfo_xl.py`](./examples/run_transfo_xl.py) - Show how to load and evaluate a pre-trained model of `TransfoXLLMHeadModel` on WikiText 103.

  These examples are detailed in the [Examples](#examples) section of this readme.

- Three notebooks that were used to check that the TensorFlow and PyTorch models behave identically (in the [`notebooks` folder](./notebooks)):
  - [`Comparing-TF-and-PT-models.ipynb`](./notebooks/Comparing-TF-and-PT-models.ipynb) - Compare the hidden states predicted by `BertModel`,
  - [`Comparing-TF-and-PT-models-SQuAD.ipynb`](./notebooks/Comparing-TF-and-PT-models-SQuAD.ipynb) - Compare the spans predicted by `BertForQuestionAnswering` instances,
  - [`Comparing-TF-and-PT-models-MLM-NSP.ipynb`](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb) - Compare the predictions of the `BertForPreTraining` instances.

  These notebooks are detailed in the [Notebooks](#notebooks) section of this readme.

- A command-line interface to convert TensorFlow checkpoints (BERT, Transformer-XL) or NumPy checkpoints (OpenAI GPT) into a PyTorch save of the associated PyTorch model.

  This CLI is detailed in the [Command-line interface](#Command-line-interface) section of this readme.

## Usage

### BERT

Here is a quick-start example using the `BertTokenizer`, `BertModel` and `BertForMaskedLM` classes with Google AI's pre-trained `bert-base-uncased` model. See the [doc section](#doc) below for all the details on these classes.

First let's prepare a tokenized input with `BertTokenizer`

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']

# Convert tokens to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
```
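The `segments_ids` list above is written out by hand. As a sketch, assuming the input follows the `[CLS] A [SEP] B [SEP]` layout used in this example, the ids can be derived from the tokens: every token up to and including the first `[SEP]` belongs to sentence A (id 0) and the rest to sentence B (id 1). `segment_ids_from_tokens` is a hypothetical helper, not part of the package.

```python
from typing import List


def segment_ids_from_tokens(tokens: List[str], sep_token: str = "[SEP]") -> List[int]:
    """Hypothetical helper: 0 for sentence A (through the first [SEP]), 1 afterwards."""
    segment_ids = []
    current = 0
    for token in tokens:
        segment_ids.append(current)
        if token == sep_token and current == 0:
            # Tokens after the first separator belong to sentence B.
            current = 1
    return segment_ids


tokens = ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]',
          'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
print(segment_ids_from_tokens(tokens))
# [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
```

The result matches the hand-written `segments_ids` for this input; inputs with a different layout would need a different rule.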

Let's see how to use `BertModel` to get hidden states

```python
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)
model.to(device)

# Predict hidden states features for each layer
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# We have a hidden state for each of the 12 layers in model bert-base-uncased
assert len(encoded_layers) == 12
```

And how to use `BertForMaskedLM`

```python
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)
model.to(device)

# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```

### OpenAI GPT

Here is a quick-start example using the `OpenAIGPTTokenizer`, `OpenAIGPTModel` and `OpenAIGPTLMHeadModel` classes with OpenAI's pre-trained model. See the [doc section](#doc) below for all the details on these classes.

First let's prepare a tokenized input with `OpenAIGPTTokenizer`

```python
import torch
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel

# Load pre-trained model tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Convert tokens to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
```

Let's see how to use `OpenAIGPTModel` to get hidden states

```python
# Load pre-trained model (weights)
model = OpenAIGPTModel.from_pretrained('openai-gpt')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor = tokens_tensor.to(device)
model.to(device)

# Predict hidden states features for each layer
with torch.no_grad():
    hidden_states = model(tokens_tensor)
```

And how to use `OpenAIGPTLMHeadModel`

```python
# Load pre-trained model (weights)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor = tokens_tensor.to(device)
model.to(device)

# Predict all tokens
with torch.no_grad():
    predictions = model(tokens_tensor)

# get the predicted last token
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == '.</w>'
```
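The expected prediction `'.</w>'` in the example above carries the end-of-word marker `</w>` used in OpenAI GPT's BPE vocabulary. As a rough sketch, assuming a token ends a word exactly when it ends with `</w>`, a list of such BPE tokens can be turned back into text by concatenating the pieces and replacing each marker with a space. `bpe_tokens_to_text` is a hypothetical helper, not a method of `OpenAIGPTTokenizer`, and the token list is made up for illustration.

```python
from typing import List


def bpe_tokens_to_text(tokens: List[str], end_of_word: str = "</w>") -> str:
    """Hypothetical helper: join BPE pieces, turning end-of-word markers into spaces."""
    return "".join(tokens).replace(end_of_word, " ").strip()


# Illustrative tokens: word-internal pieces lack the marker, word-final pieces carry it.
tokens = ["jim</w>", "hen", "son</w>", "was</w>", "a</w>", "pupp", "eteer</w>", ".</w>"]
print(bpe_tokens_to_text(tokens))  # jim henson was a puppeteer .
```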

### Transformer-XL

Here is a quick-start example using the `TransfoXLTokenizer`, `TransfoXLModel` and `TransfoXLLMHeadModel` classes with the Transformer-XL model pre-trained on WikiText-103. See the [doc section](#doc) below for all the details on these classes.

First let's prepare a tokenized input with `TransfoXLTokenizer`

```python
import torch
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLModel, TransfoXLLMHeadModel

# Load pre-trained model tokenizer (vocabulary from wikitext 103)
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

# Tokenized input
text_1 = "Who was Jim Henson ?"
text_2 = "Jim Henson was a puppeteer"
tokenized_text_1 = tokenizer.tokenize(text_1)
tokenized_text_2 = tokenizer.tokenize(text_2)

# Convert tokens to vocabulary indices
indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)

# Convert inputs to PyTorch tensors
tokens_tensor_1 = torch.tensor([indexed_tokens_1])
tokens_tensor_2 = torch.tensor([indexed_tokens_2])
```

Let's see how to use `TransfoXLModel` to get hidden states

```python
# Load pre-trained model (weights)
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor_1 = tokens_tensor_1.to(device)
tokens_tensor_2 = tokens_tensor_2.to(device)
model.to(device)

with torch.no_grad():
    # Predict hidden states features for each layer
    hidden_states_1, mems_1 = model(tokens_tensor_1)
    # We can re-use the memory cells in a subsequent call to attend a longer context
    hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
```
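Passing `mems=mems_1` lets the second segment attend to hidden states cached from the first one. The sketch below illustrates only the bookkeeping idea behind such a cache, under simplifying assumptions: a layer's memory is modeled as a plain Python list of per-position states (strings here), and a hypothetical `update_memory` function appends the new segment and keeps the most recent `mem_len` positions. It is not the package's implementation, which works on per-layer tensors; in the Transformer-XL paper the cached states are also excluded from gradient computation, which is not modeled here.

```python
from typing import List, Sequence


def update_memory(memory: Sequence[str], new_states: Sequence[str], mem_len: int) -> List[str]:
    """Hypothetical helper: append a segment's states and keep only the last `mem_len`."""
    combined = list(memory) + list(new_states)
    # Older positions fall out of the cache once it exceeds mem_len entries.
    return combined[-mem_len:] if mem_len > 0 else []


mem_len = 4
memory: List[str] = []
for segment in (["h1", "h2", "h3"], ["h4", "h5"], ["h6"]):
    # A real model would let `segment` attend over `memory + segment` at this point.
    context = memory + segment
    print(f"segment {segment} attends over {context}")
    memory = update_memory(memory, segment, mem_len)

print(memory)  # ['h3', 'h4', 'h5', 'h6']
```

The guard on `mem_len > 0` avoids the Python slicing pitfall where `combined[-0:]` would return the whole list instead of an empty cache.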

And how to use `TransfoXLLMHeadModel`

```python
# Load pre-trained model (weights)
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

# If you have a GPU, put everything on cuda (falls back to CPU otherwise)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokens_tensor_1 = tokens_tensor_1.to(device)
tokens_tensor_2 = tokens_tensor_2.to(device)
model.to(device)

with torch.no_grad():
    # Predict all tokens
    predictions_1, mems_1 = model(tokens_tensor_1)
    # We can re-use the memory cells in a subsequent call to attend a longer context
    predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)

# get the predicted last token
predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'who'
```

## Doc

Here is detailed documentation of the classes in the package and how to use them:

| Sub-section | Description |
|-|-|
| [Loading Google AI's/OpenAI's pre-trained weights](#Loading-Google-AI-or-OpenAI-pre-trained-weights-or-PyTorch-dump) | How to load Google AI's/OpenAI's pre-trained weights or a PyTorch saved instance |
| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` and `BertForQuestionAnswering` |
| [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class |
| [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class |

### Loading Google AI or OpenAI pre-trained weights or PyTorch dump

To load one of Google AI's or OpenAI's pre-trained models, or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as follows:

```python
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
```

where

- `BERT_CLASS` is either a tokenizer to load the vocabulary (`BertTokenizer`, `OpenAIGPTTokenizer` or `TransfoXLTokenizer` classes) or one of the eight BERT, three OpenAI GPT or two Transformer-XL PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice`, `BertForQuestionAnswering`, `OpenAIGPTModel`, `OpenAIGPTLMHeadModel`, `OpenAIGPTDoubleHeadsModel`, `TransfoXLModel` or `TransfoXLLMHeadModel`, and
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:

  - the shortcut name of a pre-trained model selected from the list:

    - `bert-base-uncased`: 12-layer, 768-hidden, 12-heads, 110M parameters
    - `bert-large-uncased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
    - `bert-base-cased`: 12-layer, 768-hidden, 12-heads, 110M parameters
    - `bert-large-cased`: 24-layer, 1024-hidden, 16-heads, 340M parameters
    - `bert-base-multilingual-uncased`: (Orig, not recommended) 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
    - `bert-base-multilingual-cased`: **(New, recommended)** 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
    - `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
    - `openai-gpt`: OpenAI English model, 12-layer, 768-hidden, 12-heads, 110M parameters
    - `transfo-xl-wt103`: Transformer-XL English model trained on wikitext-103, 18-layer, 1024-hidden, 16-heads, 257M parameters

  - a path or url to a pretrained model archive containing:

    - `bert_config.json` or `openai_gpt_config.json`: a configuration file for the model, and
    - `pytorch_model.bin`: a PyTorch dump of a pre-trained instance of `BertForPreTraining`, `OpenAIGPTModel` or `TransfoXLModel` (saved with the usual `torch.save()`).

  If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future downloads (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information).

`Uncased` means that the text has been lowercased before WordPiece tokenization, e.g., `John Smith` becomes `john smith`. The Uncased model also strips out any accent markers. `Cased` means that the true case and accent markers are preserved. Typically, the Uncased model is better unless you know that case information is important for your task (e.g., Named Entity Recognition or Part-of-Speech tagging). For information about the Multilingual and Chinese model, see the [Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md) or the original TensorFlow repository.
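To make the `Uncased` preprocessing concrete, here is a minimal sketch using only the standard library: lowercase the text, decompose it with Unicode NFD normalization, and drop the resulting combining marks (Unicode category `Mn`). It follows the description above (lowercasing plus accent stripping) but is an illustration, not the package's `BasicTokenizer` code; `uncased_normalize` is a hypothetical name.

```python
import unicodedata


def uncased_normalize(text: str) -> str:
    """Hypothetical helper: lowercase the text and strip accent markers."""
    decomposed = unicodedata.normalize("NFD", text.lower())
    # After NFD, accents become separate combining characters (category "Mn").
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")


print(uncased_normalize("John Smith"))    # john smith
print(uncased_normalize("Crème Brûlée"))  # creme brulee
```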

**When using an `uncased` model, make sure to pass `--do_lower_case` to the example training scripts (or pass `do_lower_case=True` to `BertTokenizer` if you're using your own script and loading the tokenizer yourself).**

Examples:
```python
from pytorch_pretrained_bert import (BertTokenizer, BertForSequenceClassification,
                                     OpenAIGPTTokenizer, OpenAIGPTModel,
                                     TransfoXLTokenizer, TransfoXLModel)

# BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# OpenAI GPT
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')

# Transformer-XL
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
```

### PyTorch models

#### 1. `BertModel`

`BertModel` is the basic BERT Transformer model with a layer of summed token, position and sequence embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large).

The inputs and output are **identical to the TensorFlow model inputs and outputs**.

We detail them here (see [`modeling.py`](./pytorch_pretrained_bert/modeling.py)). This model takes as *inputs*:
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)), and
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1] (1 for real tokens, 0 for padding). This mask is needed when some input sequences are shorter than the longest sequence in the current batch, which is the usual case when a batch contains sentences of varying length.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
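The sketch below shows one way to assemble `input_ids`, `token_type_ids` and `attention_mask` for a single sentence pair, following the `[CLS] A [SEP] B [SEP]` layout of the BERT paper. The toy vocabulary and `build_bert_pair_inputs` are hypothetical; in practice the ids come from `BertTokenizer.convert_tokens_to_ids`, and the example scripts also truncate long pairs, which this sketch does not.

```python
# Toy vocabulary for illustration only; real ids come from BertTokenizer.
TOY_VOCAB = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "hello": 3, "world": 4, "hi": 5}

def build_bert_pair_inputs(tokens_a, tokens_b, vocab, max_seq_length):
    """Pack a sentence pair as [CLS] A [SEP] B [SEP] and pad to max_seq_length."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    token_type_ids = [0] * len(tokens)            # sentence A, incl. [CLS] and first [SEP]
    tokens += tokens_b + ["[SEP]"]
    token_type_ids += [1] * (len(tokens_b) + 1)   # sentence B, incl. final [SEP]
    input_ids = [vocab[t] for t in tokens]
    attention_mask = [1] * len(input_ids)         # 1 = real token
    padding = max_seq_length - len(input_ids)
    if padding < 0:
        raise ValueError("pair too long: truncate tokens_a / tokens_b first")
    input_ids += [vocab["[PAD]"]] * padding
    token_type_ids += [0] * padding
    attention_mask += [0] * padding               # 0 = padding, not attended to
    return input_ids, token_type_ids, attention_mask

ids, types, mask = build_bert_pair_inputs(["hello", "world"], ["hi"], TOY_VOCAB, 8)
print(ids)    # [1, 3, 4, 2, 5, 2, 0, 0]
print(types)  # [0, 0, 0, 0, 1, 1, 0, 0]
print(mask)   # [1, 1, 1, 1, 1, 1, 0, 0]
```

To feed such lists to `BertModel`, add a batch dimension and convert them to tensors, e.g. `torch.tensor([ids])`.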

This model *outputs* a tuple composed of:

- `encoded_layers`: controlled by the value of the `output_all_encoded_layers` argument:

  - `output_all_encoded_layers=True`: outputs a list of the encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  - `output_all_encoded_layers=False`: outputs only the encoded-hidden-states corresponding to the last attention block, i.e. a single torch.FloatTensor of size [batch_size, sequence_length, hidden_size],

- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).

An example of how to use this class is given in the [`extract_features.py`](./examples/extract_features.py) script, which can be used to extract the hidden states of the model for a given input.

#### 2. `BertForPreTraining`

`BertForPreTraining` includes the `BertModel` Transformer followed by the two pre-training heads:

- the masked language modeling head, and
- the next sentence classification head.

*Inputs* comprise the inputs of the [`BertModel`](#-1.-`BertModel`) class plus two optional labels:

- `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]
- `next_sentence_label`: next sentence classification labels: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
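A minimal sketch of how `masked_lm_labels` can be built for one sequence, assuming the positions to mask have already been chosen. The BERT paper selects about 15% of the tokens at random and replaces a selected token with `[MASK]` only 80% of the time (10% random token, 10% unchanged); this sketch skips the random selection and always substitutes the mask id. `make_mlm_example` and the toy ids are hypothetical.

```python
def make_mlm_example(input_ids, positions_to_mask, mask_id):
    """Return (masked_input_ids, masked_lm_labels) for a single sequence.

    Labels hold the original token id at masked positions and -1 elsewhere,
    so the loss is only computed on the masked positions.
    """
    masked_ids = list(input_ids)
    masked_lm_labels = [-1] * len(input_ids)
    for pos in positions_to_mask:
        masked_lm_labels[pos] = input_ids[pos]  # the model must recover this id
        masked_ids[pos] = mask_id               # hide it from the model
    return masked_ids, masked_lm_labels

# toy ids: 1 = [CLS], 2 = [SEP], 4 = [MASK]; the others are arbitrary words
masked, labels = make_mlm_example([1, 7, 8, 9, 2], positions_to_mask=[2], mask_id=4)
print(masked)  # [1, 7, 4, 9, 2]
print(labels)  # [-1, -1, 8, -1, -1]
```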

*Outputs*:

- if `masked_lm_labels` and `next_sentence_label` are not `None`: outputs the total loss, which is the sum of the masked language modeling loss and the next sentence classification loss.
- if `masked_lm_labels` or `next_sentence_label` is `None`: outputs a tuple comprising
  - the masked language modeling logits, and
  - the next sentence classification logits.
  
An example of how to use this class is given in the [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) script, which can be used to fine-tune the BERT language model on your own text corpus. This should improve model performance if the language style of your corpus differs from the original BERT training corpus (Wiki + BookCorpus).


#### 3. `BertForMaskedLM`

`BertForMaskedLM` includes the `BertModel` Transformer followed by the (possibly) pre-trained masked language modeling head.

*Inputs* comprise the inputs of the [`BertModel`](#-1.-`BertModel`) class plus an optional label:

- `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]
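The rule that labels set to -1 are ignored can be illustrated with a pure-Python cross-entropy that skips those positions and averages over the remaining ones; this is the behaviour that PyTorch's `torch.nn.CrossEntropyLoss(ignore_index=-1)` provides. `masked_lm_loss` is a hypothetical helper written for illustration, not the model's code.

```python
import math

def masked_lm_loss(logits, labels, ignore_index=-1):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: one list of vocabulary scores per position; labels: one int per position.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unmasked position: contributes nothing to the loss
        m = max(scores)  # log-sum-exp with max subtraction for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]  # equals -log softmax(scores)[label]
        count += 1
    return total / count if count else 0.0

# only the second position is masked; its correct token (id 0) has score 2.0
print(masked_lm_loss([[0.0, 0.0], [2.0, 0.0]], [-1, 0]))  # about 0.127
```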

*Outputs*:

- if `masked_lm_labels` is not `None`: Outputs the masked language modeling loss.
- if `masked_lm_labels` is `None`: Outputs the masked language modeling logits.

#### 4. `BertForNextSentencePrediction`

`BertForNextSentencePrediction` includes the `BertModel` Transformer followed by the next sentence classification head.

*Inputs* comprise the inputs of the [`BertModel`](#-1.-`BertModel`) class plus an optional label:

- `next_sentence_label`: next sentence classification labels: torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.

*Outputs*:

- if `next_sentence_label` is not `None`: Outputs the next sentence classification loss.
- if `next_sentence_label` is `None`: Outputs the next sentence classification logits.

#### 5. `BertForSequenceClassification`

`BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`.

The sequence-level classifier is a linear layer that takes as input the last hidden state of the first token (`[CLS]`) of the input sequence (see Figures 3a and 3b in the BERT paper).
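A simplified pure-Python sketch of such a head: take the vector of the first (`[CLS]`) position, apply a linear layer, and pick the highest-scoring label. It deliberately ignores details such as the pooler and dropout that the real model may apply before the linear layer; `classify_sequence` and the toy numbers are hypothetical.

```python
def classify_sequence(hidden_states, weight, bias):
    """hidden_states: sequence_length x hidden_size (nested lists);
    weight: num_labels x hidden_size; bias: num_labels.
    Returns (logits, predicted_label) computed from the first token only."""
    cls_vector = hidden_states[0]  # first position holds the [CLS] token
    logits = [sum(w * h for w, h in zip(row, cls_vector)) + b
              for row, b in zip(weight, bias)]
    predicted = max(range(len(logits)), key=lambda i: logits[i])
    return logits, predicted

# toy example: 2 tokens, hidden_size 2, 2 labels
logits, label = classify_sequence([[1.0, 0.0], [0.0, 1.0]],
                                  weight=[[1.0, 0.0], [0.0, 1.0]],
                                  bias=[0.0, 0.5])
print(logits, label)  # [1.0, 0.5] 0
```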

An example of how to use this class is given in the [`run_classifier.py`](./examples/run_classifier.py) script, which can be used to fine-tune a single sequence (or pair of sequences) classifier using BERT, for example for the MRPC task.

#### 6. `BertForMultipleChoice`

`BertForMultipleChoice` is a fine-tuning model that includes `BertModel` and a linear layer on top of the `BertModel`.

The linear layer outputs a single value for each choice of a multiple choice problem, then all the outputs corresponding to an instance are passed through a softmax to get the model choice.
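The selection step described above (one score per choice, softmax over the choices of an instance, pick the most probable one) can be sketched in pure Python as follows; `choose_answer` is a hypothetical helper operating on the scores of a single instance.

```python
import math

def choose_answer(choice_scores):
    """choice_scores: one scalar per choice for a single instance.
    Returns (probabilities, index_of_best_choice)."""
    m = max(choice_scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in choice_scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return probs, best

probs, best = choose_answer([0.0, 0.0, math.log(2.0)])
print(probs, best)  # approximately [0.25, 0.25, 0.5] and 2
```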

This implementation is largely inspired by the work of OpenAI in [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) and the answer of Jacob Devlin in the following [issue](https://github.com/google-research/bert/issues/38).

An example of how to use this class is given in the [`run_swag.py`](./examples/run_swag.py) script, which can be used to fine-tune a multiple choice classifier using BERT, for example for the SWAG task.

#### 7. `BertForTokenClassification`

`BertForTokenClassification` is a fine-tuning model that includes `BertModel` and a token-level classifier on top of the `BertModel`.

The token-level classifier is a linear layer that takes as input the full sequence of last hidden states and outputs a set of label scores for each token.

#### 8. `BertForQuestionAnswering`

`BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states.

The token-level classifier takes as input the full sequence of the last hidden state and computes several (e.g. two) scores for each token, which can for example be the scores that a given token is a `start_span` and an `end_span` token, respectively (see Figures 3c and 3d in the BERT paper).
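Given per-token start and end scores, an answer span is typically read off by maximising `start_score + end_score` over valid pairs. The sketch below is a simplified, brute-force version of that search for one example; `best_span` and its `max_answer_length` default of 30 are assumptions for illustration, and it ignores the n-best lists and the mapping back to the original text that a full SQuAD script handles.

```python
def best_span(start_scores, end_scores, max_answer_length=30):
    """Return (start, end, score) maximising start_scores[s] + end_scores[e]
    subject to s <= e < s + max_answer_length."""
    best = (0, 0, float("-inf"))
    for s, s_score in enumerate(start_scores):
        # the end must not precede the start and the span must stay short enough
        for e in range(s, min(s + max_answer_length, len(end_scores))):
            score = s_score + end_scores[e]
            if score > best[2]:
                best = (s, e, score)
    return best

start = [0.1, 3.0, 0.2, 0.0]
end = [0.0, 0.5, 2.0, 4.0]
print(best_span(start, end))                       # (1, 3, 7.0)
print(best_span(start, end, max_answer_length=2))  # (1, 2, 5.0)
```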

An example of how to use this class is given in the [`run_squad.py`](./examples/run_squad.py) script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.

#### 9. `OpenAIGPTModel`

`OpenAIGPTModel` is the basic OpenAI GPT Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.

OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
Special token embeddings are additional tokens that are not pre-trained: `[SEP]`, `[CLS]`...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

The embeddings are ordered as follows in the token embeddings matrix:

```
    [0,                                                         ----------------------
      ...                                                        -> word embeddings
      config.vocab_size - 1,                                     ______________________
      config.vocab_size,
      ...                                                        -> special embeddings
      config.vocab_size + config.n_special - 1]                  ______________________
```

The total number of rows in this matrix, `total_tokens_embeddings`, can be obtained as `config.total_tokens_embeddings` and is equal to `config.vocab_size + config.n_special`.
You should use the associated indices to index the embeddings.
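Following the layout above, special tokens occupy the rows right after the word embeddings. A minimal sketch, assuming the special tokens are numbered in the order they are declared; `special_token_ids`, the token names and the vocabulary size below are hypothetical toy values, not GPT's real ones.

```python
def special_token_ids(vocab_size, special_tokens):
    """Map each special token to its row in the shared embedding matrix:
    rows [0, vocab_size) are word embeddings, special tokens follow in order."""
    return {tok: vocab_size + i for i, tok in enumerate(special_tokens)}

ids = special_token_ids(vocab_size=1000,
                        special_tokens=["_start_", "_delimiter_", "_classify_"])
total_tokens_embeddings = 1000 + len(ids)  # vocab_size + n_special
print(ids)                      # {'_start_': 1000, '_delimiter_': 1001, '_classify_': 1002}
print(total_tokens_embeddings)  # 1003
```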

The inputs and output are **identical to the TensorFlow model inputs and outputs**.

We detail them here (see [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py)). This model takes as *inputs*:
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
- `position_ids`: an optional torch.LongTensor with the same shape as `input_ids`
    with the position indices, selected in the range [0, config.n_positions - 1].
- `token_type_ids`: an optional torch.LongTensor with the same shape as `input_ids`.
    You can use it to add a third type of embedding to each input token in the sequence
    (the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.

This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)

#### 10. `OpenAIGPTLMHeadModel`

`OpenAIGPTLMHeadModel` includes the `OpenAIGPTModel` Transformer followed by a language modeling head with weights tied to the input embeddings (no additional parameters).

*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus optional labels:
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].

*Outputs*:
- if `lm_labels` is not `None`:
  Outputs the language modeling loss.
- else:
  Outputs `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings] (or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are the dimensions of input_ids)

#### 11. `OpenAIGPTDoubleHeadsModel`

`OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters), and
- a multiple choice classifier (a linear layer that takes as input a hidden state in a sequence to compute a score, see details in the paper).

*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification mask and two optional labels:
- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices - 1].
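For one instance, the choices usually have different lengths, so they must be padded to a common length, and `multiple_choice_token_ids` records where the classification token of each choice sits. A minimal sketch, assuming each choice already ends with a classification token; `pack_choices`, the ids and the `pad_id` value are hypothetical (padded positions should also get an `lm_labels` value of -1 so they are ignored).

```python
def pack_choices(choices_ids, clf_id, pad_id=0):
    """choices_ids: list (num_choices) of token-id lists for one instance.
    Returns (padded_ids, clf_positions), suitable after adding a batch
    dimension for input_ids [1, num_choices, seq_len] and
    multiple_choice_token_ids [1, num_choices]."""
    max_len = max(len(ids) for ids in choices_ids)
    padded = [ids + [pad_id] * (max_len - len(ids)) for ids in choices_ids]
    # index of the classification token in each choice (ValueError if missing)
    positions = [ids.index(clf_id) for ids in choices_ids]
    return padded, positions

padded, positions = pack_choices([[5, 6, 99], [5, 7, 8, 99]], clf_id=99)
print(padded)     # [[5, 6, 99, 0], [5, 7, 8, 99]]
print(positions)  # [2, 3]
```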

*Outputs*:
- if `lm_labels` and `multiple_choice_labels` are not `None`:
  Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
- else Outputs a tuple with:
  - `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
  - `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]

#### 12. `TransfoXLModel`

The Transformer-XL model is described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".

Transformer-XL uses relative positional embeddings with sinusoidal patterns and adaptive softmax inputs, which means that:

- you don't need to specify positioning embeddings indices, and
- the tokens in the vocabulary have to be sorted by decreasing frequency.

This model takes as *inputs* (see [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py)):
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the token indices selected in the range [0, self.config.n_token[
- `mems`: an optional memory of hidden states from previous forward passes as a list (num layers) of hidden states at the entry of each layer. Each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.

This model *outputs* a tuple of (last_hidden_state, new_mems):
- `last_hidden_state`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `new_mems` with regards to `input_ids`.
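Conceptually, each layer's new memory is the old memory followed by the current segment's hidden states along the time axis, truncated to the most recent `mem_len` steps. The sketch below illustrates this with plain lists, assuming no extended context and ignoring the gradient stop applied to memories; `update_mems` is hypothetical, and each list element stands in for one [batch_size, d_model] time slice (the real tensors are concatenated along dimension 0).

```python
def update_mems(mems, hiddens, mem_len):
    """mems, hiddens: per-layer lists of time steps.
    Returns per-layer memories keeping only the last mem_len steps."""
    new_mems = []
    for old, new in zip(mems, hiddens):
        concatenated = old + new  # time axis: old memory first, then new segment
        # guard mem_len == 0 because a [-0:] slice would keep everything
        new_mems.append(concatenated[-mem_len:] if mem_len > 0 else [])
    return new_mems

# one layer, previous memory of 2 steps, new segment of 3 steps
print(update_mems([["a1", "a2"]], [["b1", "b2", "b3"]], mem_len=4))
# [['a2', 'b1', 'b2', 'b3']]
```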

#### 13. `TransfoXLLMHeadModel`

`TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.

*Inputs* are the same as the inputs of the [`TransfoXLModel`](#-12.-`TransfoXLModel`) class plus optional labels:
- `target`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the target token indices selected in the range [0, self.config.n_token[

*Outputs* a tuple of (softmax_output, new_mems):
- `softmax_output`: output of the (adaptive) softmax:
  - if `target` is `None`: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
  - else: negative log likelihood of the targets, shape [batch_size, sequence_length]
- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `new_mems` with regards to `input_ids`.

### Tokenizers:

#### `BertTokenizer`

`BertTokenizer` performs end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.

This class has two arguments:

- `vocab_file`: path to a vocabulary file.
- `do_lower_case`: convert text to lower-case while tokenizing. **Default = True**.

and three methods:

- `tokenize(text)`: convert a `str` into a list of `str` tokens by (1) performing basic tokenization and (2) WordPiece tokenization.
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens into a list of `int` indices in the vocabulary.
- `convert_ids_to_tokens(ids)`: convert a list of `int` indices into a list of `str` tokens in the vocabulary.
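The WordPiece step splits each basic token greedily into the longest vocabulary pieces, marking non-initial pieces with a `##` prefix and falling back to an unknown token when no split exists. The sketch below is a simplified illustration of that greedy longest-match-first idea, not the code in `tokenization.py`; `wordpiece`, the toy vocabulary and the `max_chars` default are assumptions.

```python
def wordpiece(word, vocab, unk_token="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece split of a single word."""
    if len(word) > max_chars:
        return [unk_token]
    pieces, start = [], 0
    while start < len(word):
        end, current = len(word), None
        while start < end:  # try the longest remaining substring first
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces are prefixed
            if candidate in vocab:
                current = candidate
                break
            end -= 1
        if current is None:
            return [unk_token]  # some part could not be matched: whole word unknown
        pieces.append(current)
        start = end
    return pieces

toy_vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece("xyz", toy_vocab))        # ['[UNK]']
```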

Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) for the details of the `BasicTokenizer` and `WordpieceTokenizer` classes. In general it is recommended to use `BertTokenizer` unless you know what you are doing.

#### `OpenAIGPTTokenizer`

`OpenAIGPTTokenizer` performs Byte-Pair-Encoding (BPE) tokenization.

This class has two arguments:

- `vocab_file`: path to a vocabulary file.
- `merges_file`: path to a file containing the BPE merges.

and three methods:

- `tokenize(text)`: convert a `str` into a list of `str` tokens by (1) performing basic tokenization and (2) BPE tokenization.
- `convert_tokens_to_ids(tokens)`: convert a list of `str` tokens into a list of `int` indices in the vocabulary.
- `convert_ids_to_tokens(ids)`: convert a list of `int` indices into a list of `str` tokens in the vocabulary.

Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch_pretrained_bert/tokenization_openai.py) for the details of the `OpenAIGPTTokenizer`.
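The core of BPE is applying the learned merges in priority order: split a word into characters, then repeatedly merge the adjacent pair with the best (lowest) rank until no learned merge applies. A simplified sketch, assuming an end-of-word marker `</w>` on the last character; `bpe` and the toy merge list are hypothetical and skip the caching and text cleanup of the real tokenizer.

```python
def bpe(word, merges):
    """Apply ranked BPE merges to one word.

    merges: list of (left, right) symbol pairs, highest priority first.
    """
    if not word:
        return []
    ranks = {pair: i for i, pair in enumerate(merges)}
    symbols = list(word[:-1]) + [word[-1] + "</w>"]
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no learned merge applies any more
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

toy_merges = [("l", "o"), ("lo", "w</w>")]
print(bpe("low", toy_merges))    # ['low</w>']
print(bpe("lower", toy_merges))  # ['lo', 'w', 'e', 'r</w>']
```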

#### `TransfoXLTokenizer`

`TransfoXLTokenizer` performs word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency (for adaptive softmax). See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.
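Building a frequency-ordered vocabulary amounts to counting whitespace-separated tokens and assigning index 0 to the most frequent one, so that the frequent tokens land in the first adaptive softmax cluster. The sketch below uses `collections.Counter` for this; `build_frequency_vocab` and `min_freq` are hypothetical and are not the tokenizer's actual method names.

```python
from collections import Counter

def build_frequency_vocab(corpus_lines, min_freq=1):
    """Return (idx2sym, sym2idx) with the most frequent token at index 0."""
    counts = Counter()
    for line in corpus_lines:
        counts.update(line.split())
    # most_common() yields tokens by decreasing count
    idx2sym = [sym for sym, c in counts.most_common() if c >= min_freq]
    sym2idx = {sym: i for i, sym in enumerate(idx2sym)}
    return idx2sym, sym2idx

idx2sym, sym2idx = build_frequency_vocab(["the cat sat", "the cat", "the"])
print(idx2sym)  # ['the', 'cat', 'sat']
```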

Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.

### Optimizers:

#### `BertAdam`

`BertAdam` is a `torch.optim.Optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of BERT. The differences with the PyTorch Adam optimizer are the following:

- BertAdam implements the weight decay fix, and
- BertAdam doesn't perform the bias correction of the regular Adam optimizer.

The optimizer accepts the following arguments:

- `lr` : learning rate
- `warmup` : portion of `t_total` for the warmup, `-1`  means no warmup. Default : `-1`
- `t_total` : total number of training steps for the learning rate schedule, `-1`  means constant learning rate. Default : `-1`
- `schedule` : schedule to use for the warmup (see above). Default : `'warmup_linear'`
- `b1` : Adam's b1. Default : `0.9`
- `b2` : Adam's b2. Default : `0.999`
- `e` : Adam's epsilon. Default : `1e-6`
- `weight_decay` : Weight decay. Default : `0.01`
- `max_grad_norm` : Maximum norm for the gradients (`-1` means no clipping). Default : `1.0`
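To make the `warmup` and `t_total` arguments concrete, here is a sketch of a linear warmup followed by a linear decay to zero, in the spirit of the `'warmup_linear'` schedule. The exact formula used by this library has changed between releases, so treat `lr_at_step` as a hypothetical illustration rather than the optimizer's implementation.

```python
def lr_at_step(step, base_lr, t_total, warmup=0.1):
    """Learning rate after `step` optimizer steps.

    Linear warmup over the first `warmup` fraction of `t_total` steps, then
    linear decay to 0 at `t_total`. `t_total == -1` means a constant rate.
    """
    if t_total == -1:
        return base_lr
    if not 0.0 <= warmup < 1.0:
        raise ValueError("this sketch expects 0 <= warmup < 1")
    progress = step / t_total
    if progress < warmup:
        return base_lr * progress / warmup
    return base_lr * max((1.0 - progress) / (1.0 - warmup), 0.0)

# 100 training steps, 10% warmup: ramps up to base_lr at step 10, reaches 0 at step 100
for step in (0, 5, 10, 55, 100):
    print(step, lr_at_step(step, base_lr=1.0, t_total=100))
```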

#### `OpenAIGPTAdam`

`OpenAIGPTAdam` is similar to `BertAdam`.
The difference with `BertAdam` is that `OpenAIGPTAdam` performs bias correction, as in the regular Adam optimizer.

`OpenAIGPTAdam` accepts the same arguments as `BertAdam`.
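The difference between the two optimizers can be illustrated on a single scalar parameter. The sketch below applies one Adam-style step with weight decay added to the update (rather than to the gradient), and a `bias_correction` flag switching between the behaviour described for `BertAdam` (off) and a regular Adam-style correction (on). `adam_step` is hypothetical; the real optimizers work on tensors and also apply the learning-rate schedule and gradient clipping, and `OpenAIGPTAdam` may differ in details.

```python
import math

def adam_step(p, g, m, v, step, lr, b1=0.9, b2=0.999, e=1e-6,
              weight_decay=0.01, bias_correction=False):
    """One update of scalar parameter p with gradient g; `step` counts from 1.
    Returns (new_p, new_m, new_v)."""
    m = b1 * m + (1 - b1) * g          # running mean of gradients
    v = b2 * v + (1 - b2) * g * g      # running mean of squared gradients
    if bias_correction:                # Adam's correction for zero-initialised m, v
        m_used = m / (1 - b1 ** step)
        v_used = v / (1 - b2 ** step)
    else:                              # no correction, as described for BertAdam
        m_used, v_used = m, v
    update = m_used / (math.sqrt(v_used) + e)
    # "weight decay fix": decay is added to the update, so it is not rescaled
    # by the adaptive denominator as plain L2 regularisation would be
    update += weight_decay * p
    return p - lr * update, m, v

p_bert, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, step=1, lr=0.1, weight_decay=0.0)
p_adam, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, step=1, lr=0.1, weight_decay=0.0,
                         bias_correction=True)
print(p_bert, p_adam)  # without correction the first step is much larger
```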

## Examples

| Sub-section | Description |
|-|-|
| [Training large models: introduction, tools and examples](#Training-large-models-introduction,-tools-and-examples) | How to use gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bit training to train BERT models |
| [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_features.py`, `run_classifier.py`, `run_squad.py` and `run_lm_finetuning.py` |
| [Fine-tuning BERT-large on GPUs](#Fine-tuning-BERT-large-on-GPUs) | How to fine-tune `BERT large` |

### Training large models: introduction, tools and examples

BERT-base and BERT-large are respectively 110M and 340M parameter models, and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32).

To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bit training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.

Here is how to use these techniques in our scripts:

- **Gradient Accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps.
- **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are split over the GPUs.
- **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument (see below).
- **16-bit training**: 16-bit training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing you to double the batch size. If you have a recent GPU (starting from the NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and the full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scale` flag (see the previously linked documentation for details on loss scaling). The loss scale can be zero, in which case the scale is dynamically adjusted, or a positive power of two, in which case the scaling is static.
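Why gradient accumulation gives the same result as one large batch can be checked numerically: if each micro-batch loss is divided by the number of accumulation steps (an assumption about how the scripts scale the loss), the summed micro-batch gradients equal the full-batch mean gradient. In the sketch below, `accumulated_gradient` is hypothetical and scalar per-example gradients stand in for parameter gradients.

```python
def accumulated_gradient(per_example_grads, accumulation_steps):
    """Accumulate scaled micro-batch mean gradients over `accumulation_steps`
    micro-batches; the result equals the mean over the full batch."""
    n = len(per_example_grads)
    if n % accumulation_steps != 0:
        raise ValueError("batch size must be divisible by accumulation_steps")
    micro = n // accumulation_steps  # micro-batch size actually fed to the model
    total = 0.0
    for k in range(accumulation_steps):
        chunk = per_example_grads[k * micro:(k + 1) * micro]
        total += (sum(chunk) / micro) / accumulation_steps  # scaled, then accumulated
    return total  # the optimizer steps once with this accumulated gradient

grads = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(accumulated_gradient(grads, 4), sum(grads) / len(grads))  # 4.5 4.5
```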

To use 16-bits training and distributed training, you need to install NVIDIA's apex extension [as detailed here](https://github.com/nvidia/apex). You will find more information regarding the internals of `apex` and how to use `apex` in [the doc and the associated repository](https://github.com/nvidia/apex). The results of the tests performed on pytorch-BERT by the NVIDIA team (and my trials at reproducing them) can be consulted in [the relevant PR of the present repository](https://github.com/huggingface/pytorch-pretrained-BERT/pull/116).

Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) for more details):
```bash
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script)
```
Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`.
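With `--nproc_per_node=4 --nnodes=2` as in the command above, eight training processes are started in total. The sketch below shows the node-major numbering we believe `torch.distributed.launch` uses to derive each process's global rank from the machine index and the per-machine `--local_rank`; `world_size` and `global_rank` are hypothetical helpers for illustration.

```python
def world_size(nnodes, nproc_per_node):
    """Total number of training processes across all machines."""
    return nnodes * nproc_per_node

def global_rank(node_rank, local_rank, nproc_per_node):
    """Rank of a process across all machines (node-major ordering)."""
    return node_rank * nproc_per_node + local_rank

print(world_size(2, 4))      # 8
print(global_rank(0, 0, 4))  # 0: first process on the machine at 192.168.1.1
print(global_rank(1, 3, 4))  # 7: last process on the second machine
```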

### Fine-tuning with BERT: running the examples

We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):

- a *sequence-level classifier* on the MRPC classification corpus,
- a *token-level classifier* on the question answering dataset SQuAD,
- a *sequence-level multiple-choice classifier* on the SWAG classification corpus, and
- a *BERT language model* on another target corpus.
 
#### MRPC

This example code fine-tunes BERT on the Microsoft Research Paraphrase
Corpus (MRPC) and runs in less than 10 minutes on a single K-80 and in 27 seconds (!) on a single Tesla V100 16GB with apex installed.

Before running this example you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.

```shell
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name MRPC \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/MRPC/ \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/mrpc_output/
```

Our tests, run on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks), gave evaluation results between 84% and 88%.

**Fast run with apex and 16-bit precision: fine-tuning on MRPC in 27 seconds!**
First install apex as indicated [here](https://github.com/NVIDIA/apex).
Then run:
```shell
export GLUE_DIR=/path/to/glue

python run_classifier.py \
  --task_name MRPC \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/MRPC/ \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/mrpc_output/ \
  --fp16
```

#### SQuAD

This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single Tesla V100 16GB.

The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.

*   [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
*   [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
*   [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)

```shell
export SQUAD_DIR=/path/to/SQUAD

python run_squad.py \
  --bert_model bert-base-uncased \
  --do_train \
  --do_predict \
  --do_lower_case \
  --train_file $SQUAD_DIR/train-v1.1.json \
  --predict_file $SQUAD_DIR/dev-v1.1.json \
  --train_batch_size 12 \
  --learning_rate 3e-5 \
  --num_train_epochs 2.0 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir /tmp/debug_squad/
```
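
`--max_seq_length` and `--doc_stride` control how contexts longer than the model input are handled: the document is cut into overlapping windows, each paired with the question. The sketch below restates that windowing on token counts only; it mirrors our reading of the logic in `run_squad.py` (including the assumption that 3 positions are reserved for `[CLS]` and the two `[SEP]` tokens), so check the script for the authoritative version.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, num_query_tokens: int,
              max_seq_length: int = 384, doc_stride: int = 128) -> List[Tuple[int, int]]:
    """Split a long document into overlapping (start, length) windows."""
    # Room left for document tokens once the question and the
    # [CLS] / [SEP] / [SEP] special tokens are placed in the sequence.
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # this window already reaches the end of the document
        # Slide forward by at most doc_stride tokens so windows overlap.
        start += min(length, doc_stride)
    return spans


print(doc_spans(num_doc_tokens=700, num_query_tokens=20))
```

With the default settings, a 700-token context and a 20-token question produce four windows of at most 361 document tokens, each starting 128 tokens after the previous one, so every answer span appears with some surrounding context in at least one window.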

Training with the previous hyper-parameters gave us the following results:
```bash
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
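
`f1` and `exact_match` are the SQuAD v1.1 metrics computed by `evaluate-v1.1.py`. As a guide to what those numbers mean, here is a small re-implementation written from our understanding of that script (answer normalization, token-overlap F1, best score over the reference answers, scaled to 0-100); the official script remains the reference, and it additionally averages these per-question scores over the whole dev set.

```python
import re
import string
from collections import Counter
from typing import Dict, Sequence

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and the articles a/an/the, collapse spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1(prediction: str, truth: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def score(prediction: str, truths: Sequence[str]) -> Dict[str, float]:
    # Each question keeps its best score over all reference answers.
    return {
        "exact_match": 100.0 * max(exact_match(prediction, t) for t in truths),
        "f1": 100.0 * max(f1(prediction, t) for t in truths),
    }


print(score("Paris, France", ["Paris", "the city of Paris"]))
```

The partial-credit behaviour of F1 explains why it sits several points above exact match in the results above: a prediction such as "Paris, France" earns no exact match against "Paris" but still gets an F1 of about 66.7.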

#### SWAG

The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf).

```shell
export SWAG_DIR=/path/to/SWAG

python run_swag.py \
  --bert_model bert-base-uncased \
  --do_train \
  --do_lower_case \
  --do_eval \
  --data_dir $SWAG_DIR/data \
  --train_batch_size 16 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --max_seq_length 80 \
  --output_dir /tmp/swag_output/ \
  --gradient_accumulation_steps 4
```

Training with the previous hyper-parameters on a single GPU gave us the following results:
```
eval_accuracy = 0.8062081375587323
eval_loss = 0.5966546792367169
global_step = 13788
loss = 0.06423990014260186
```
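
SWAG is a multiple-choice task: each example has a context and four candidate endings, and `eval_accuracy` is the fraction of examples whose highest-scoring ending is the correct one. The sketch below shows that framing with plain strings and made-up scores. Pairing the context with "start of the second sentence + ending" follows our reading of `run_swag.py`; the string layout with literal `[CLS]`/`[SEP]` markers is a simplification, since the script works on WordPiece token ids and the model emits one score per ending.

```python
from typing import List, Sequence


def build_choice_inputs(context: str, start: str, endings: Sequence[str]) -> List[str]:
    """Turn one SWAG-style example into one BERT sentence pair per ending."""
    # Each candidate ending completes the start of the second sentence,
    # giving the layout: [CLS] context [SEP] start + ending [SEP]
    return [f"[CLS] {context} [SEP] {start} {ending} [SEP]" for ending in endings]


def accuracy(choice_scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose highest-scoring ending is the gold one."""
    correct = 0
    for scores, label in zip(choice_scores, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)


pairs = build_choice_inputs("She opens the fridge.", "She",
                            ["takes out milk.", "flies away.", "sings loudly.", "reads a map."])
print(pairs[0])
# Made-up model scores for two examples, four endings each.
print(accuracy([[0.1, 2.3, -1.0, 0.4], [1.5, 0.2, 0.3, 0.0]], labels=[1, 2]))
```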

#### LM Fine-tuning

The data should be a text file in the same format as [sample_text.txt](./samples/sample_text.txt) (one sentence per line, documents separated by an empty line).
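
A quick way to sanity-check a corpus before fine-tuning is to parse it the way this format describes. `read_documents` below is a hypothetical pre-check of ours, not part of `run_lm_finetuning.py`; it groups non-empty lines into documents and treats any run of blank lines as a single separator.

```python
from typing import Iterable, List


def read_documents(lines: Iterable[str]) -> List[List[str]]:
    """Group a one-sentence-per-line corpus into documents (lists of sentences)."""
    documents: List[List[str]] = []
    current: List[str] = []
    for line in lines:
        sentence = line.strip()
        if sentence:
            current.append(sentence)
        elif current:
            # An empty line closes the current document.
            documents.append(current)
            current = []
    if current:
        documents.append(current)  # the file may not end with a blank line
    return documents


corpus = """The cat sat on the mat.
It was warm.

A second document starts here.
"""
docs = read_documents(corpus.splitlines())
print(len(docs), [len(d) for d in docs])
```

Documents with a single sentence, or a file that parses into a single document, are usually a sign that the sentence or document boundaries were lost when the corpus was exported.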
You can download an [exemplary training corpus](https://ext-bert-sample.obs.eu-de.otc.t-systems.com/small_wiki_sentence_corpus.txt) generated from Wikipedia articles and split into ~500k sentences with spaCy.
Training one epoch on this corpus takes about 1:20h on 4 x NVIDIA Tesla P100 with `train_batch_size=200` and `max_seq_length=128`:

```shell
python run_lm_finetuning.py \
  --bert_model bert-base-uncased \
  --do_lower_case \
  --do_train \
  --train_file ../samples/sample_text.txt \
  --output_dir models \
  --num_train_epochs 5.0 \
  --learning_rate 3e-5 \
  --train_batch_size 32 \
  --max_seq_length 128
```

### OpenAI GPT and Transformer-XL: running the examples

We provide two example scripts for OpenAI GPT and Transformer-XL, based on (and extended from) the respective original implementations:

- fine-tuning OpenAI GPT on the ROCStories dataset
- evaluating Transformer-XL on Wikitext 103

#### Fine-tuning OpenAI GPT on the RocStories dataset

This example code fine-tunes OpenAI GPT on the RocStories dataset.

Before running this example you should download the
[RocStories dataset](https://github.com/snigdhac/StoryComprehension_EMNLP/tree/master/Dataset/RoCStories) and unpack it to some directory `$ROC_STORIES_DIR`.

```shell
export ROC_STORIES_DIR=/path/to/RocStories

python run_openai_gpt.py \
  --model_name openai-gpt \
  --do_train \
  --do_eval \
  --train_dataset $ROC_STORIES_DIR/cloze_test_val__spring2016\ -\ cloze_test_ALL_val.csv \
  --eval_dataset $ROC_STORIES_DIR/cloze_test_test__spring2016\ -\ cloze_test_ALL_test.csv \
  --output_dir ../log \
  --train_batch_size 16
```

This command runs in about 10 min on a single K-80 and gives an evaluation accuracy of about 86.4% (the authors report a median accuracy of 85.8% with the TensorFlow code, and the OpenAI GPT paper reports a best single-run accuracy of 86.5%).

#### Evaluating the pre-trained Transformer-XL on the WikiText 103 dataset

This example code evaluates the pre-trained Transformer-XL on the WikiText 103 dataset.
This command will download a pre-processed version of the WikiText 103 dataset in which the vocabulary has already been computed.

```shell
python run_transfo_xl.py --work_dir ../log
```

This command runs in about 1 min on a V100 and gives an evaluation perplexity of 18.22 on WikiText-103 (the authors report a perplexity of about 18.3 on this dataset with the TensorFlow code).
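
Perplexity is the exponential of the average per-token negative log-likelihood, so a perplexity of about 18 means the model is, on average, as uncertain as a uniform choice among about 18 tokens. A minimal sketch of the computation from per-token natural-log probabilities (the inputs here are synthetic, not Transformer-XL outputs):

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """Perplexity = exp(average negative log-likelihood per token).

    token_log_probs holds the natural-log probabilities the model assigned
    to each target token of the evaluation text.
    """
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# A model that gives every token probability 1/18 has a perplexity of about 18.
print(perplexity([math.log(1 / 18)] * 1000))
```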

## Fine-tuning BERT-large on GPUs

The options we list above allow you to fine-tune BERT-large rather easily on GPU(s) instead of the TPU used by the original implementation.

For example, fine-tuning BERT-large on SQuAD can be done on a server with 4 K-80 GPUs (these are pretty old now) in 18 hours. Our results are similar to the TensorFlow implementation results (actually slightly higher):
```bash
{"exact_match": 84.56953642384106, "f1": 91.04028647786927}
```
To get these results we used a combination of:
- multi-GPU training (automatically activated on a multi-GPU server),
- 2 steps of gradient accumulation, and
- performing the optimization step on CPU to store Adam's averages in RAM.
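
Gradient accumulation trades time for memory: instead of one backward pass over the whole batch, the batch is split into chunks, their gradients are summed, and the optimizer steps once. In our reading of the example scripts, which divide `--train_batch_size` by `--gradient_accumulation_steps`, the run listed in this section (`--train_batch_size 24`, `--gradient_accumulation_steps 2`) does two forward/backward passes of 12 examples per optimizer step. The toy sketch below uses a one-parameter least-squares model in plain Python, not PyTorch, to check that the accumulated gradient equals the full-batch gradient.

```python
from typing import Sequence, Tuple

Example = Tuple[float, float]  # (x, y)


def grad_mse(w: float, batch: Sequence[Example]) -> float:
    """d/dw of mean((w * x - y) ** 2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Sequence[Example], accumulation_steps: int) -> float:
    """Full-batch gradient computed as several smaller backward passes."""
    micro = len(batch) // accumulation_steps  # assumes the batch divides evenly
    total = 0.0
    for step in range(accumulation_steps):
        chunk = batch[step * micro:(step + 1) * micro]
        # Dividing each chunk's loss by the number of steps makes the summed
        # gradients equal the gradient of the full-batch mean loss.
        total += grad_mse(w, chunk) / accumulation_steps
    return total


batch = [(float(i), 2.0 * i + 1.0) for i in range(24)]  # 24 toy examples
full = grad_mse(0.5, batch)
accum = accumulated_grad(0.5, batch, accumulation_steps=2)  # 2 passes of 12
print(full, accum)
```

Only the chunk currently being processed needs to hold activations in memory, which is what lets a large effective batch fit on smaller GPUs.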

Here is the full list of hyper-parameters for this run:
```bash
python ./run_squad.py \
  --bert_model bert-large-uncased \
  --do_train \
  --do_predict \
  --do_lower_case \
  --train_file $SQUAD_TRAIN \
  --predict_file $SQUAD_EVAL \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir $OUTPUT_DIR \
  --train_batch_size 24 \
  --gradient_accumulation_steps 2
```

If you have a recent GPU (starting from the NVIDIA Volta series), you should try **16-bit fine-tuning** (FP16).

Here is an example of hyper-parameters for an FP16 run we tried:
```bash
python ./run_squad.py \
  --bert_model bert-large-uncased \
  --do_train \
  --do_predict \
  --do_lower_case \
  --train_file $SQUAD_TRAIN \
  --predict_file $SQUAD_EVAL \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir $OUTPUT_DIR \
  --train_batch_size 24 \
  --fp16 \
  --loss_scale 128
```

The results were similar to the above FP32 results (actually slightly higher):
```bash
{"exact_match": 84.65468306527909, "f1": 91.238669287002}
```
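
`--fp16` runs the computation in half precision, and `--loss_scale 128` (static loss scaling, as we understand the option) multiplies the loss by a constant before backpropagation so that small gradients are not flushed to zero in fp16; the gradients are divided by the same constant before the optimizer step. The sketch below only illustrates the numerics, using the standard library's half-precision `struct` format `e` to simulate fp16 rounding; it is not how apex implements loss scaling.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct raises OverflowError if the value is too large for fp16.
    """
    return struct.unpack("<e", struct.pack("<e", value))[0]


def fp16_gradient(true_grad: float, loss_scale: float) -> float:
    """Simulate static loss scaling for a single gradient value.

    Backpropagating the scaled loss scales the gradient too; that scaled
    gradient is stored in fp16, then unscaled in full precision.
    """
    scaled_in_fp16 = to_fp16(true_grad * loss_scale)
    return scaled_in_fp16 / loss_scale


tiny_grad = 1e-8  # below the smallest positive fp16 value (about 6e-8)
print(fp16_gradient(tiny_grad, loss_scale=1))    # underflows to 0.0: the update is lost
print(fp16_gradient(tiny_grad, loss_scale=128))  # roughly 1e-8 survives
```

The scale has to be large enough to lift small gradients into the fp16 range but small enough that the largest gradients do not overflow, which is why a power of two such as 128 is a common starting point.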

## Notebooks

We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.

- The first notebook ([Comparing-TF-and-PT-models.ipynb](./notebooks/Comparing-TF-and-PT-models.ipynb)) extracts the hidden states of a full sequence at each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models.

- The second notebook ([Comparing-TF-and-PT-models-SQuAD.ipynb](./notebooks/Comparing-TF-and-PT-models-SQuAD.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.

- The third notebook ([Comparing-TF-and-PT-models-MLM-NSP.ipynb](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb)) compares the predictions computed by the TensorFlow and the PyTorch models for masked token language modeling using the pre-trained masked language modeling model.
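
The first two comparisons boil down to measuring how far two sets of activations are from each other. A minimal, framework-free sketch of that kind of check, assuming both outputs have already been flattened into lists of floats; the notebooks themselves work on NumPy/PyTorch arrays and may compute the statistic slightly differently, and the values below are illustrative only.

```python
import statistics
from typing import Dict, Sequence


def compare_outputs(tf_values: Sequence[float], pt_values: Sequence[float]) -> Dict[str, float]:
    """Summarize the element-wise differences between two model outputs."""
    if len(tf_values) != len(pt_values):
        raise ValueError("outputs must have the same number of elements")
    diffs = [a - b for a, b in zip(tf_values, pt_values)]
    return {
        "std_of_diff": statistics.pstdev(diffs),      # spread of the differences
        "max_abs_diff": max(abs(d) for d in diffs),   # worst single element
    }


tf_hidden = [0.12, -0.53, 1.07, 0.88]  # illustrative values only
pt_hidden = [0.12 + 1e-7, -0.53, 1.07 - 2e-7, 0.88]
report = compare_outputs(tf_hidden, pt_hidden)
print(report)
```

Differences on the order of 1e-7 are consistent with float32 rounding noise, which is why they indicate that the two implementations compute the same function.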

Please follow the instructions given in the notebooks to run and modify them.

## Command-line interface

A command-line interface is provided to convert a TensorFlow checkpoint into a PyTorch dump of the `BertForPreTraining` class (for BERT) or a NumPy checkpoint into a PyTorch dump of the `OpenAIGPTModel` class (for OpenAI GPT).

### BERT

You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) into a PyTorch save file by using the [`./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py`](./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py) script.

This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in [`extract_features.py`](./examples/extract_features.py), [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py)).

You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.

To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.

Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:

```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
  $BERT_BASE_DIR/bert_model.ckpt \
  $BERT_BASE_DIR/bert_config.json \
  $BERT_BASE_DIR/pytorch_model.bin
```

You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
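
After the conversion, the PyTorch side still needs the configuration and vocabulary files next to the dumped weights. `missing_model_files` is a hypothetical pre-flight check of ours; the three file names follow the conversion example above (`pytorch_model.bin` is simply the output name used there), so adjust them if your paths differ.

```python
import tempfile
from pathlib import Path
from typing import List

# File names taken from the conversion example above; adjust if yours differ.
REQUIRED_FILES = ("pytorch_model.bin", "bert_config.json", "vocab.txt")


def missing_model_files(model_dir: str) -> List[str]:
    """Return the required files that are absent from model_dir."""
    directory = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (directory / name).is_file()]


# Demonstration on a throw-away directory: empty first, then fully populated.
with tempfile.TemporaryDirectory() as tmp:
    before = missing_model_files(tmp)
    for name in REQUIRED_FILES:
        (Path(tmp) / name).touch()
    after = missing_model_files(tmp)
print(before, after)
```

Run against `$BERT_BASE_DIR` after conversion, an empty list means the TensorFlow checkpoint files can be discarded without losing anything the PyTorch model needs.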

### OpenAI GPT

Here is an example of the conversion process for a pre-trained OpenAI GPT model, assuming that your NumPy checkpoint is saved in the same format as the OpenAI pre-trained model (see [here](https://github.com/openai/finetune-transformer-lm)):

```shell
export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights

pytorch_pretrained_bert convert_openai_checkpoint \
  $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
  $PYTORCH_DUMP_OUTPUT \
  [OPENAI_GPT_CONFIG]
```

### Transformer-XL

Here is an example of the conversion process for a pre-trained Transformer-XL model (see [here](https://github.com/kimiyoung/transformer-xl/tree/master/tf#obtain-and-evaluate-pretrained-sota-models)):

```shell
export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint

pytorch_pretrained_bert convert_transfo_xl_checkpoint \
  $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
  $PYTORCH_DUMP_OUTPUT \
  [TRANSFO_XL_CONFIG]
```

## TPU

TPU support and pretraining scripts

TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPU and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)).

We will add TPU support when this next release is published.

The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py).

Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch to convert these pre-training scripts.