Unverified Commit b24201fa authored by Patrick von Platen, committed by GitHub

[Doctests] Fix all T5 doc tests (#16646)



* [Doctests] Fix all T5 doc tests

* make style

* Update docs/source/en/model_doc/t5.mdx
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply Sylvain's comments

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent f7196f2e
docs/source/en/model_doc/byt5.mdx
@@ -48,37 +48,98 @@ fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
ByT5 works on raw UTF-8 bytes, so it can be used without a tokenizer:

```python
>>> from transformers import T5ForConditionalGeneration
>>> import torch

>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")

>>> num_special_tokens = 3
>>> # Model has 3 special tokens which take up the input ids 0,1,2 of ByT5.
>>> # => Need to shift utf-8 character encodings by 3 before passing ids to model.

>>> input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + num_special_tokens

>>> labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + num_special_tokens

>>> loss = model(input_ids, labels=labels).loss
>>> loss.item()
2.66
```
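To turn tokenizer-free output ids back into text, you can invert the shift: drop ids below `num_special_tokens` and
subtract the offset before decoding the remaining bytes. The sketch below is only an illustration — the ids are made
up, and ids 0-2 are assumed to be special tokens such as pad/eos:

```python
# Minimal sketch: map shifted byte ids back to a string without a tokenizer.
num_special_tokens = 3
generated_ids = [82, 110, 49, 1]  # hypothetical model output ("Ok." followed by a special token)
byte_values = [i - num_special_tokens for i in generated_ids if i >= num_special_tokens]
text = bytes(byte_values).decode("utf-8")
print(text)  # Ok.
```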
For batched inference and training, however, it is recommended to use the tokenizer:

```python
>>> from transformers import T5ForConditionalGeneration, AutoTokenizer

>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

>>> model_inputs = tokenizer(
...     ["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt"
... )
>>> labels_dict = tokenizer(
...     ["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt"
... )
>>> labels = labels_dict.input_ids

>>> loss = model(**model_inputs, labels=labels).loss
>>> loss.item()
17.9
```
Similar to [T5](t5), ByT5 was trained on the span-mask denoising task. However,
since the model works directly on characters, the pretraining task is a bit
different. Let's corrupt some characters of the
input sentence `"The dog chases a ball in the park."` and ask ByT5 to predict them
for us.

```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-base")

>>> input_ids_prompt = "The dog chases a ball in the park."
>>> input_ids = tokenizer(input_ids_prompt).input_ids

>>> # Note that we cannot add "{extra_id_...}" to the string directly
>>> # as the Byte tokenizer would incorrectly merge the tokens.
>>> # For ByT5, we need to work directly on the character level.
>>> # Contrary to T5, ByT5 does not use sentinel tokens for masking, but instead
>>> # uses final utf character ids.
>>> # UTF-8 is represented by 8 bits and ByT5 has 3 special tokens.
>>> # => There are 2**8+3 = 259 input ids and mask tokens count down from index 258.
>>> # => mask to "The dog [258] a ball [257] park."

>>> input_ids = torch.tensor([input_ids[:8] + [258] + input_ids[14:21] + [257] + input_ids[28:]])
>>> input_ids
tensor([[ 87, 107, 104,  35, 103, 114, 106,  35, 258,  35, 100,  35, 101, 100, 111, 111, 257,  35, 115, 100, 117, 110,  49,   1]])

>>> # ByT5 produces only one char at a time so we need to produce many more output characters here -> set `max_length=100`.
>>> output_ids = model.generate(input_ids, max_length=100)[0].tolist()
>>> output_ids
[0, 258, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 257, 35, 108, 113, 35, 119, 107, 104, 35, 103, 108, 118, 102, 114, 256, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49, 35, 87, 107, 104, 35, 103, 114, 106, 35, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 35, 100, 35, 101, 100, 111, 111, 35, 108, 113, 255, 35, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49]

>>> # ^- Note how 258 descends to 257, 256, 255

>>> # Now we need to split on the sentinel tokens, let's write a short loop for this
>>> output_ids_list = []
>>> start_token = 0
>>> sentinel_token = 258
>>> while sentinel_token in output_ids:
...     split_idx = output_ids.index(sentinel_token)
...     output_ids_list.append(output_ids[start_token:split_idx])
...     start_token = split_idx
...     sentinel_token -= 1

>>> output_ids_list.append(output_ids[start_token:])
>>> output_string = tokenizer.batch_decode(output_ids_list)
>>> output_string
['<pad>', 'is the one who does', ' in the disco', 'in the park. The dog is the one who does a ball in', ' in the park.']
```
## ByT5Tokenizer

[[autodoc]] ByT5Tokenizer
......
docs/source/en/model_doc/t5.mdx
@@ -32,9 +32,9 @@ NLP, we release our dataset, pre-trained models, and code.*
Tips:

- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
for summarization: *summarize: ...*.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
@@ -83,130 +83,140 @@ language modeling head on top of the decoder.
- Unsupervised denoising training

In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens) and
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
`<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
[`T5Tokenizer`].

For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
processed as follows:

```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids

>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
>>> loss.item()
3.7837
```

If you're interested in pre-training T5 on a new corpus, check out the
[run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script
in the Examples directory.

- Supervised training

In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
Suppose, for example, that we want to fine-tune the model for translation, and we have a training example: the input
sequence "The house is wonderful." and the output sequence "Das Haus ist wunderbar.". They should be prepared for
the model as follows:

```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
>>> labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
>>> loss.item()
0.2542
```
As you can see, only two inputs are required for the model to compute a loss: `input_ids` (which are the
`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
English to German: ' before encoding it. This helps improve performance, as this task prefix was used
during T5's pre-training.
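To make the shift described above concrete, the snippet below reproduces it by hand on a toy `labels` tensor. This is
only an illustration of the behaviour, not the model's internal code (which additionally replaces any `-100` entries
in the shifted ids with the pad token id):

```python
# Illustration only: build decoder_input_ids from labels as described above.
import torch

decoder_start_token_id = 0  # for T5 this equals the pad token id
labels = torch.tensor([[644, 4598, 229, 3, 1]])  # toy label ids ending in </s> (id 1)

# shift one position to the right and prepend the decoder start token
decoder_input_ids = torch.cat(
    [torch.full((labels.shape[0], 1), decoder_start_token_id, dtype=labels.dtype), labels[:, :-1]],
    dim=-1,
)
print(decoder_input_ids)  # tensor([[   0,  644, 4598,  229,    3]])
```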
However, the example above only shows a single training example. In practice, one trains deep learning models in
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
the task.

In addition, we must make sure that padding token ids of the `labels` are not taken into account by the loss
function. In PyTorch and TensorFlow, this can be done by replacing them with -100, which is the `ignore_index`
of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
ignored. The code example below illustrates all of this.
```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> import torch

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> # the following 2 hyperparameters are task-specific
>>> max_source_length = 512
>>> max_target_length = 128

>>> # Suppose we have the following 2 training examples:
>>> input_sequence_1 = "Welcome to NYC"
>>> output_sequence_1 = "Bienvenue à NYC"

>>> input_sequence_2 = "HuggingFace is a company"
>>> output_sequence_2 = "HuggingFace est une entreprise"

>>> # encode the inputs
>>> task_prefix = "translate English to French: "
>>> input_sequences = [input_sequence_1, input_sequence_2]

>>> encoding = tokenizer(
...     [task_prefix + sequence for sequence in input_sequences],
...     padding="longest",
...     max_length=max_source_length,
...     truncation=True,
...     return_tensors="pt",
... )

>>> input_ids, attention_mask = encoding.input_ids, encoding.attention_mask

>>> # encode the targets
>>> target_encoding = tokenizer(
...     [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
... )
>>> labels = target_encoding.input_ids

>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
>>> labels = torch.tensor(labels)
>>> labels[labels == tokenizer.pad_token_id] = -100

>>> # forward pass
>>> loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
>>> loss.item()
0.188
```
Additional training tips:

- T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer (see the sketch after this
list for a typical fixed-learning-rate setup).
- According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
used).
- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
batches to the longest example is not recommended on TPU, as it triggers a recompilation for every batch shape
encountered during training and therefore leads to very slow training.
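As a companion to the first tip above, here is a minimal sketch of using the `Adafactor` optimizer shipped with
Transformers with a fixed learning rate; the exact hyperparameter values are illustrative assumptions, not a
prescription:

```python
# Minimal sketch (illustrative hyperparameters): fine-tune T5 with Adafactor at a fixed lr.
from transformers import Adafactor, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Disable Adafactor's relative-step schedule so the fixed learning rate of 1e-4 is actually used.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-4,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)
```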
<a id='inference'></a>
@@ -219,15 +229,15 @@ There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encode
generation works in general in encoder-decoder models.

```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Das Haus ist wunderbar.
```
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
@@ -236,31 +246,47 @@ Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when do
The example above only shows a single example. You can also do batched inference, like so:

```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> task_prefix = "translate English to German: "
>>> sentences = [
...     "The house is wonderful.",
...     "I like to work in NYC.",
... ]  # use different length sentences to test batching

>>> inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)

>>> output_sequences = model.generate(
...     input_ids=inputs["input_ids"],
...     attention_mask=inputs["attention_mask"],
...     do_sample=False,  # disable sampling to test if batching affects output
... )

>>> print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
```
Because T5 has been trained with the span-mask denoising objective,
it can be used to predict the sentinel (masked-out) tokens during inference.
The predicted tokens will then be placed between the sentinel tokens.

```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> sequence_ids = model.generate(input_ids)
>>> sequences = tokenizer.batch_decode(sequence_ids)
>>> sequences
['<pad> <extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']
```
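If you need the text predicted for each sentinel token on its own, you can split the decoded string on the
`<extra_id_*>` markers yourself. A minimal sketch (the parsing below is an ad-hoc helper, not part of the library):

```python
# Minimal sketch: extract the fill-in text predicted for each sentinel token.
import re

decoded = "<pad> <extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>"
fills = re.split(r"<extra_id_\d+>", decoded)[1:]  # drop the leading "<pad> "
fills = [fill.replace("</s>", "").strip() for fill in fills]
print(fills)  # ['park offers', 'the', 'park.']
```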
<a id='scripts'></a>

## Performance
......
docs/source/en/model_doc/t5v1_1.mdx
@@ -20,9 +20,9 @@ repository by Colin Raffel et al. It's an improved version of the original T5 mo
One can directly plug in the weights of T5v1.1 into a T5 model, like so:

```python
>>> from transformers import T5ForConditionalGeneration

>>> model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
```

T5 Version 1.1 includes the following improvements compared to the original T5 model:
......
docs/source/en/quicktour.mdx
docs/source/en/task_summary.mdx
docs/source/en/model_doc/speech_to_text.mdx
docs/source/en/model_doc/t5.mdx
docs/source/en/model_doc/t5v1_1.mdx
docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
src/transformers/generation_utils.py
src/transformers/models/bart/modeling_bart.py
......