<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Fine-tuning a pretrained model

[[open-in-colab]]

In this tutorial, we will show you how to fine-tune a pretrained model from the 🤗 Transformers library. In TensorFlow,
models can be directly trained using Keras and the `fit` method. In PyTorch, there is no generic training loop, so
the 🤗 Transformers library provides the [`Trainer`] API to let you easily fine-tune or train
a model from scratch. We will then show you how to alternatively write the whole training loop in native PyTorch.

Before we can fine-tune a model, we need a dataset. In this tutorial, we will show you how to fine-tune BERT on the
[IMDB dataset](https://www.imdb.com/interfaces/): the task is to classify whether movie reviews are positive or
negative. For examples of other tasks, refer to the [additional resources](#additional-resources) section!

<a id='data-processing'></a>

## Preparing the datasets

<Youtube id="_BZearw7f0w"/>

We will use the [🤗 Datasets](https://github.com/huggingface/datasets/) library to download and preprocess the IMDB
dataset. We will go over this part pretty quickly. Since the focus of this tutorial is on training, you should refer
to the 🤗 Datasets [documentation](https://huggingface.co/docs/datasets/) or the [preprocessing](preprocessing) tutorial for
more information.

First, we can use the `load_dataset` function to download and cache the dataset:

```python
from datasets import load_dataset

raw_datasets = load_dataset("imdb")
```

This works like the `from_pretrained` method we saw for the models and tokenizers (except the cache directory is
_~/.cache/huggingface/dataset_ by default).

The `raw_datasets` object is a dictionary with three keys: `"train"`, `"test"` and `"unsupervised"`
(which correspond to the three splits of that dataset). We will use the `"train"` split for training and the
`"test"` split for validation.

To preprocess our data, we will need a tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
```

As we saw in [preprocessing](preprocessing), we can prepare the text inputs for the model with the following command (this is an
example, not a command you can execute):

```python
inputs = tokenizer(sentences, padding="max_length", truncation=True)
```

This will make all the samples have the maximum length the model can accept (here 512), either by padding or truncating
them.
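
Using the dataset we just loaded, a minimal runnable variant (an illustrative sketch, not one of the tutorial's steps) confirms the resulting length:

```python
sample = raw_datasets["train"][0]["text"]
encoded = tokenizer(sample, padding="max_length", truncation=True)
print(len(encoded["input_ids"]))  # 512, the maximum length bert-base-cased accepts
```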

However, we can instead apply these preprocessing steps to all the splits of our dataset at once by using the
`map` method:

```python
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
```

You can learn more about the `map` method and the other ways to preprocess data in the 🤗 Datasets [documentation](https://huggingface.co/docs/datasets/).

Next, we will generate small subsets of the training and validation sets, to enable faster training:

```python
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]
```

In all the examples below, we will always use `small_train_dataset` and `small_eval_dataset`. Just replace
them with their _full_ equivalents to train or evaluate on the full dataset.

<a id='trainer'></a>

## Fine-tuning in PyTorch with the Trainer API

<Youtube id="nvBXf7s7vTI"/>

Since PyTorch does not provide a generic training loop, the 🤗 Transformers library provides a [`Trainer`]
API that is optimized for 🤗 Transformers models, with a wide range of training options and built-in features like
logging, gradient accumulation, and mixed precision.

First, let's define our model:

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
```

This will issue a warning about some of the pretrained weights not being used and some weights being randomly
initialized. That's because we are throwing away the pretraining head of the BERT model to replace it with a
classification head which is randomly initialized. We will fine-tune this model on our task, transferring the knowledge
of the pretrained model to it (which is why doing this is called transfer learning).

Then, to define our [`Trainer`], we will need to instantiate a
[`TrainingArguments`]. This class contains all the hyperparameters we can tune for the
[`Trainer`], as well as the flags to activate the different training options it supports. Let's begin by
using all the defaults; the only thing we then have to provide is the directory in which the checkpoints will be saved:

```python
from transformers import TrainingArguments

training_args = TrainingArguments("test_trainer")
```

Then we can instantiate a [`Trainer`] like this:

```python
from transformers import Trainer

trainer = Trainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset
)
```

To fine-tune our model, we just need to call

```python
trainer.train()
```

which will start training that you can follow with a progress bar and that should take a couple of minutes to complete
(as long as you have access to a GPU). It won't actually tell you anything useful about how well (or badly) your model
is performing, however: by default, there is no evaluation during training, and we didn't tell the
[`Trainer`] to compute any metrics. Let's have a look at how to do that now!

To have the [`Trainer`] compute and report metrics, we need to give it a `compute_metrics`
function that takes predictions and labels (grouped in a namedtuple called [`EvalPrediction`]) and
returns a dictionary with string keys (the metric names) and float values (the metric values).

The 馃 Datasets library provides an easy way to get the common metrics used in NLP with the `load_metric` function.
here we simply use accuracy. Then we define the `compute_metrics` function that just convert logits to predictions
(remember that all 馃 Transformers models return the logits) and feed them to `compute` method of this metric.

```python
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
```

The `compute_metrics` function needs to receive a tuple (with logits and labels) and has to return a dictionary with string keys
(the metric names) and float values. It will be called at the end of each evaluation phase, on the whole arrays of
predictions and labels.

To check if this works in practice, let's create a new [`Trainer`] with our fine-tuned model:

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.evaluate()
```

which showed an accuracy of 87.5% in our case.

If you want to fine-tune your model and regularly report the evaluation metrics (for instance at the end of each
epoch), here is how you should define your training arguments:

```python
from transformers import TrainingArguments

training_args = TrainingArguments("test_trainer", evaluation_strategy="epoch")
```

See the documentation of [`TrainingArguments`] for more options.
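
For instance, here is a sketch combining a few commonly used options (the values are illustrative, not recommendations):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    "test_trainer",
    evaluation_strategy="epoch",  # report metrics at the end of each epoch
    learning_rate=5e-5,  # initial learning rate for the optimizer
    per_device_train_batch_size=8,  # training batch size per device
    per_device_eval_batch_size=8,  # evaluation batch size per device
    num_train_epochs=3,  # total number of training epochs
    weight_decay=0.01,  # weight decay applied by the optimizer
)
```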


<a id='keras'></a>

## Fine-tuning with Keras

<Youtube id="rnTGBy2ax1c"/>

Models can also be trained natively in TensorFlow using the Keras API. First, let's define our model:

```python
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
```

Then we will need to convert our datasets from before into standard `tf.data.Dataset` objects. Since we have fixed shapes,
this can easily be done. First, we remove the _"text"_ column from our datasets and set their format to
TensorFlow:

```python
tf_train_dataset = small_train_dataset.remove_columns(["text"]).with_format("tensorflow")
tf_eval_dataset = small_eval_dataset.remove_columns(["text"]).with_format("tensorflow")
```

Then we gather everything into big tensors and use the `tf.data.Dataset.from_tensor_slices` method:

```python
train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names}
train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"]))
train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8)

eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}
eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"]))
eval_tf_dataset = eval_tf_dataset.batch(8)
```

With this done, the model can then be compiled and trained like any other Keras model:

```python
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.metrics.SparseCategoricalAccuracy()],
)

model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=3)
```
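
The `fit` call above uses a fixed learning rate. If you want to mimic the linear decay the [`Trainer`] applies by default (see the PyTorch section below), here is a hedged sketch using Keras' `PolynomialDecay` schedule (its default `power=1.0` makes the decay linear); you would compile with this optimizer before calling `fit`:

```python
num_train_steps = len(train_tf_dataset) * 3  # batches per epoch times number of epochs
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5, end_learning_rate=0.0, decay_steps=num_train_steps
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.metrics.SparseCategoricalAccuracy()],
)
```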

With the tight interoperability between TensorFlow and PyTorch models, you can even save the model and then reload it
as a PyTorch model (or vice-versa):

```python
from transformers import AutoModelForSequenceClassification

model.save_pretrained("my_imdb_model")
pytorch_model = AutoModelForSequenceClassification.from_pretrained("my_imdb_model", from_tf=True)
```
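
The reverse direction works the same way, with `from_pt=True`; a quick sketch (the directory name `"my_imdb_model_pt"` is just an example):

```python
from transformers import TFAutoModelForSequenceClassification

pytorch_model.save_pretrained("my_imdb_model_pt")
tf_model = TFAutoModelForSequenceClassification.from_pretrained("my_imdb_model_pt", from_pt=True)
```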

<a id='pytorch_native'></a>

## Fine-tuning in native PyTorch

<Youtube id="Dh9CL8fyG80"/>

You might need to restart your notebook at this stage to free some memory, or execute the following code:

```python
import torch

del model
del pytorch_model
del trainer
torch.cuda.empty_cache()
```

Let's now see how to achieve the same results as in the [trainer section](#trainer) in native PyTorch. First, we need to
define the dataloaders we will use to iterate over batches. Before doing that, we need to apply a bit of post-processing to
our `tokenized_datasets` to:

- remove the columns corresponding to values the model does not expect (here, the `"text"` column)
- rename the `"label"` column to `"labels"` (because the model expects the argument to be named `labels`)
- set the format of the datasets so they return PyTorch tensors instead of lists

Our `tokenized_datasets` object has one method for each of those steps:

```python
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
```

Now that this is done, we can easily define our dataloaders:

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)
```
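
As a quick sanity check (illustrative, and the exact key set depends on the tokenizer), you can inspect the shapes of one batch:

```python
batch = next(iter(train_dataloader))
print({k: v.shape for k, v in batch.items()})
# e.g. {'labels': torch.Size([8]), 'input_ids': torch.Size([8, 512]),
#       'token_type_ids': torch.Size([8, 512]), 'attention_mask': torch.Size([8, 512])}
```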

Next, we define our model:

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
```

We are almost ready to write our training loop; the only two things missing are an optimizer and a learning rate
scheduler. The default optimizer used by the [`Trainer`] is [`AdamW`]:

```python
from transformers import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)
```
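
If you prefer, PyTorch ships its own `torch.optim.AdamW`; a hedged sketch (note that its default `weight_decay` is 0.01, while the 🤗 Transformers implementation defaults to 0.0):

```python
import torch

# weight_decay=0.0 to match the transformers.AdamW default
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.0)
```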

Finally, the learning rate scheduler used by default is just a linear decay from the maximum value (5e-5 here) to 0:

```python
from transformers import get_scheduler

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
```

One last thing: we will want to use the GPU if we have access to one (otherwise training might take several hours
instead of a couple of minutes). To do this, we define a `device` that we will put our model and our batches on.

```python
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
```

We are now ready to train! To get some sense of when training will be finished, we add a progress bar over our number of
training steps, using the _tqdm_ library.

```python
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
```

Note that if you are used to freezing the body of your pretrained model (like in computer vision), the above may seem a
bit strange, as we are directly fine-tuning the whole model without taking any precautions. It actually works better
this way for Transformers models (so this is not an oversight on our side). If you're not familiar with what "freezing
the body" of the model means, forget you read this paragraph.

Now to check the results, we need to write the evaluation loop. Like in the [trainer section](#trainer), we will
use a metric from the 🤗 Datasets library. Here we accumulate the predictions over each batch and compute the final
result once the loop is finished.

```python
from datasets import load_metric

metric = load_metric("accuracy")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()
```

<a id='additional-resources'></a>

## Additional resources

To look at more fine-tuning examples, you can refer to:

- [🤗 Transformers Examples](https://github.com/huggingface/transformers/tree/master/examples), which includes scripts
  to train on all common NLP tasks in PyTorch and TensorFlow.

- [🤗 Transformers Notebooks](notebooks), which contains various notebooks, in particular one per task (look for
  the _how to finetune a model on xxx_ ones).