<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# T5

## Overview

The T5 model was presented in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/pdf/1910.10683.pdf) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, and Peter J. Liu.

The abstract from the paper is the following:

*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
NLP, we release our dataset, pre-trained models, and code.*

Tips:

- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
  each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
  different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
  for summarization: *summarize: ...*.

- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.

- See the [training](#training), [inference](#inference) and [scripts](#scripts) sections below for all details regarding usage.

T5 comes in different sizes:

- [t5-small](https://huggingface.co/t5-small)

- [t5-base](https://huggingface.co/t5-base)

- [t5-large](https://huggingface.co/t5-large)

- [t5-3b](https://huggingface.co/t5-3b)

- [t5-11b](https://huggingface.co/t5-11b)

Based on the original T5 model, Google has released some follow-up works:

- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
  mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found [here](t5v1.1).

- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
  the documentation of mT5 which can be found [here](mt5).

- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
  to the documentation of byT5 which can be found [here](byt5).

All checkpoints can be found on the [hub](https://huggingface.co/models?search=t5).

This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The original code can be found [here](https://github.com/google-research/text-to-text-transfer-transformer).

<a id='training'></a>

## Training

T5 is an encoder-decoder model that converts all NLP problems into a text-to-text format. It is trained using teacher
forcing, meaning that training always requires an input sequence and a corresponding target sequence. The input
sequence is fed to the model using `input_ids`. The target sequence is shifted to the right, i.e., prepended with a
start-sequence token, and fed to the decoder using the `decoder_input_ids`. In teacher-forcing style, the target
sequence is then appended with the EOS token and corresponds to the `labels`. The PAD token serves as the
start-sequence token. T5 can be trained/fine-tuned in both a supervised and an unsupervised fashion.
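
For example (a minimal sketch with illustrative token ids), the decoder inputs are simply the labels shifted one
position to the right, with the PAD token prepended:

```python
import torch

# illustrative label ids, ending in T5's EOS token (id 1)
labels = torch.tensor([[37, 629, 19, 1627, 5, 1]])

# teacher forcing: shift the labels right and prepend the PAD token (id 0)
decoder_input_ids = torch.cat([torch.zeros_like(labels[:, :1]), labels[:, :-1]], dim=-1)
print(decoder_input_ids)  # [[0, 37, 629, 19, 1627, 5]]
```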

One can use [`T5ForConditionalGeneration`] (or the Tensorflow/Flax variant), which includes the
language modeling head on top of the decoder.

- Unsupervised denoising training

  In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
  the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
  sentinel token represents a unique mask token for the sentence; they are numbered `<extra_id_0>`,
  `<extra_id_1>`, ... up to `<extra_id_99>`. By default, 100 sentinel tokens are available in
  [`T5Tokenizer`].

  For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
  processed as follows:

  ```python
  from transformers import T5Tokenizer, T5ForConditionalGeneration

  tokenizer = T5Tokenizer.from_pretrained("t5-small")
  model = T5ForConditionalGeneration.from_pretrained("t5-small")

  input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  # the forward function automatically creates the correct decoder_input_ids
  loss = model(input_ids=input_ids, labels=labels).loss
  ```
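
  The sentinel tokens sit at the very end of the vocabulary (a quick check, assuming `t5-small` and its vocabulary
  size of 32,100):

  ```python
  from transformers import T5Tokenizer

  tokenizer = T5Tokenizer.from_pretrained("t5-small")
  print(len(tokenizer.additional_special_tokens))  # 100 sentinel tokens by default
  print(tokenizer.convert_tokens_to_ids("<extra_id_0>"))  # 32099
  ```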

  If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling) script in the Examples
  directory.

- Supervised training

  In this setup, the input sequence and output sequence form a standard sequence-to-sequence mapping. Suppose we want
  to fine-tune the model for translation, and we have a training example consisting of the input sequence "The house
  is wonderful." and the output sequence "Das Haus ist wunderbar.". The example should be prepared for the model as
  follows:

  ```python
  from transformers import T5Tokenizer, T5ForConditionalGeneration

  tokenizer = T5Tokenizer.from_pretrained("t5-small")
  model = T5ForConditionalGeneration.from_pretrained("t5-small")

  input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
  labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
  # the forward function automatically creates the correct decoder_input_ids
  loss = model(input_ids=input_ids, labels=labels).loss
  ```

  As you can see, only 2 inputs are required for the model to compute a loss: `input_ids` (the `input_ids` of the
  encoded input sequence) and `labels` (the `input_ids` of the encoded target sequence). The model automatically
  creates the `decoder_input_ids` based on the `labels`, by shifting them one position to the right and prepending the
  `config.decoder_start_token_id`, which for T5 is equal to 0 (i.e. the id of the pad token). Also note the task
  prefix: we prepend the input sequence with 'translate English to German: ' before encoding it. This helps improve
  performance, as this task prefix was used during T5's pre-training.
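
  You can verify the start token directly from the model configuration (a quick check using `t5-small`):

  ```python
  from transformers import T5Config

  config = T5Config.from_pretrained("t5-small")
  print(config.decoder_start_token_id, config.pad_token_id)  # 0 0
  ```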

  However, the example above only shows a single training example. In practice, one trains deep learning models in
  batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
  typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
  input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
  the task.

  In addition, we must make sure that padding token ids of the `labels` are not taken into account by the loss
  function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
  of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
  the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/master/examples/flax/summarization) for details). We also pass
  `attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
  ignored. The code example below illustrates all of this.

  ```python
  from transformers import T5Tokenizer, T5ForConditionalGeneration
  import torch

  tokenizer = T5Tokenizer.from_pretrained("t5-small")
  model = T5ForConditionalGeneration.from_pretrained("t5-small")

  # the following 2 hyperparameters are task-specific
  max_source_length = 512
  max_target_length = 128

  # Suppose we have the following 2 training examples:
  input_sequence_1 = "Welcome to NYC"
  output_sequence_1 = "Bienvenue à NYC"

  input_sequence_2 = "HuggingFace is a company"
  output_sequence_2 = "HuggingFace est une entreprise"

  # encode the inputs
  task_prefix = "translate English to French: "
  input_sequences = [input_sequence_1, input_sequence_2]
  encoding = tokenizer(
      [task_prefix + sequence for sequence in input_sequences],
      padding="longest",
      max_length=max_source_length,
      truncation=True,
      return_tensors="pt",
  )
  input_ids, attention_mask = encoding.input_ids, encoding.attention_mask

  # encode the targets
  target_encoding = tokenizer(
      [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
  )
  labels = target_encoding.input_ids

  # replace padding token id's of the labels by -100
  labels = torch.tensor(labels)
  labels[labels == tokenizer.pad_token_id] = -100

  # forward pass
  loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
  ```

Additional training tips:

- T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
  optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
  answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer (see the sketch after
  these tips).

- According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
  (1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
  pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
  used).

- If training on TPU, it is recommended to pad all examples of the dataset to the same length, or to make use of
  *pad_to_multiple_of* to have a small number of predefined bucket sizes into which all examples fit. Dynamically
  padding batches to the longest example is not recommended on TPU, as it triggers a recompilation for every batch
  shape encountered during training, which significantly slows training down.
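
As a sketch of the first and last tips above (assuming the `Adafactor` implementation shipped with Transformers and
the tokenizer's `pad_to_multiple_of` argument):

```python
from transformers import Adafactor, T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# a constant learning rate in the 1e-4 to 3e-4 range, instead of the Trainer default
optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)

# on TPU, round padding up to a fixed multiple rather than padding each
# batch exactly to its longest example
encoding = tokenizer(
    ["translate English to German: The house is wonderful."],
    padding="longest",
    pad_to_multiple_of=8,
    return_tensors="pt",
)
```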

<a id='inference'></a>

## Inference

At inference time, it is recommended to use [`~generation_utils.GenerationMixin.generate`]. This method takes care of
encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder, then
auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to learn all the details about generating text with Transformers.
There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder), which explains how
generation works in general in encoder-decoder models.

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Das Haus ist wunderbar.
```

Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
[`~generation_utils.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
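
As a minimal sketch of what that means (a hand-rolled greedy loop; `generate` handles this and much more for you):

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

# start the decoder with the pad token, T5's decoder_start_token_id
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# greedily pick the most likely next token until EOS (or a length limit)
for _ in range(20):
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    next_token = logits[:, -1:].argmax(dim=-1)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() == model.config.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
# Das Haus ist wunderbar.
```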

The examples above process a single input sequence at a time. You can also do batched inference, like so:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# T5 has a dedicated pad token, and padded positions are ignored via the
# attention mask, so no special tokenizer setup is needed for batching

task_prefix = "translate English to German: "
sentences = ["The house is wonderful.", "I like to work in NYC."]  # use different length sentences to test batching
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)

output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # disable sampling to test if batching affects output
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
```

<a id='scripts'></a>

## Performance

For faster training and inference performance, install [apex](https://github.com/NVIDIA/apex#quick-start); the model will then automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel that is several times faster than the latter.
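
A quick way to check which implementation is in use (a sketch that relies on the internal module layout of the T5
implementation, which may change):

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# the first layer norm of the first encoder block: prints "FusedRMSNorm"
# when apex is installed, "T5LayerNorm" otherwise
print(type(model.encoder.block[0].layer[0].layer_norm).__name__)
```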


## Example scripts

T5 is supported by several example scripts, both for pre-training and fine-tuning.

- pre-training: the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py)
  script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The [t5_tokenizer_model.py](https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/t5_tokenizer_model.py)
  script allows you to further train a T5 tokenizer or train a T5 tokenizer from scratch on your own data. Note that
  Flax (a neural network library on top of JAX) is particularly useful for training on TPU hardware.

- fine-tuning: T5 is supported by the official summarization scripts ([PyTorch](https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization), [Tensorflow](https://github.com/huggingface/transformers/tree/master/examples/tensorflow/summarization), and [Flax](https://github.com/huggingface/transformers/tree/master/examples/flax/summarization)) and translation scripts
  ([PyTorch](https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation) and [Tensorflow](https://github.com/huggingface/transformers/tree/master/examples/tensorflow/translation)). These scripts allow
  you to easily fine-tune T5 on custom data for summarization/translation.

## T5Config

[[autodoc]] T5Config

## T5Tokenizer

[[autodoc]] T5Tokenizer
    - build_inputs_with_special_tokens
    - get_special_tokens_mask
    - create_token_type_ids_from_sequences
    - save_vocabulary

## T5TokenizerFast

[[autodoc]] T5TokenizerFast

## T5Model

[[autodoc]] T5Model
    - forward
    - parallelize
    - deparallelize

## T5ForConditionalGeneration

[[autodoc]] T5ForConditionalGeneration
    - forward
    - parallelize
    - deparallelize

## T5EncoderModel

[[autodoc]] T5EncoderModel
    - forward
    - parallelize
    - deparallelize

## TFT5Model

[[autodoc]] TFT5Model
    - call

## TFT5ForConditionalGeneration

[[autodoc]] TFT5ForConditionalGeneration
    - call

## TFT5EncoderModel

[[autodoc]] TFT5EncoderModel
    - call

## FlaxT5Model

[[autodoc]] FlaxT5Model
    - __call__
    - encode
    - decode

## FlaxT5ForConditionalGeneration

[[autodoc]] FlaxT5ForConditionalGeneration
    - __call__
    - encode
    - decode