Unverified Commit 5b44e0a3 authored by Patrick von Platen, committed by GitHub

[T5] Add training documentation (#3507)

* Add clear description of how to train T5

* correct docstring in T5

* correct typo

* correct docstring format

* update t5 model docs

* implement collins feedback

* fix typo and add more explanation for sentinel tokens

* delete unnecessary todos
parent 33ef7002
@@ -16,13 +16,45 @@ To facilitate future work on transfer learning for NLP, we release our dataset, ...
The Authors' code can be found `here <https://github.com/google-research/text-to-text-transfer-transformer>`_ .
Training
~~~~~~~~~~~~~~~~~~~~
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher forcing.
This means that for training we always need an input sequence and a target sequence.
The input sequence is fed to the model using ``input_ids``. The target sequence is shifted to the right, *i.e.* prepended by a start-sequence token, and fed to the decoder using ``decoder_input_ids``. In teacher-forcing style, the target sequence is then appended with the EOS token and corresponds to the ``lm_labels``. The PAD token is hereby used as the start-sequence token.
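Concretely, the right shift can be pictured as follows (a minimal sketch with made-up token ids; in practice the forward function creates ``decoder_input_ids`` from ``lm_labels`` automatically, so this step rarely has to be done by hand)::

    import torch

    pad_token_id = 0  # T5 uses the PAD token (id 0) as the decoder start token
    eos_token_id = 1  # id of the </s> token
    lm_labels = torch.tensor([[8774, 6, 149, 33, eos_token_id]])  # target sequence ending in EOS
    # shift right: prepend the start (PAD) token and drop the last position
    decoder_input_ids = torch.cat(
        [torch.full((lm_labels.shape[0], 1), pad_token_id, dtype=torch.long), lm_labels[:, :-1]],
        dim=-1,
    )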
T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
- Unsupervised denoising training
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens)
and the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens.
Each sentinel token represents a unique mask token for this sentence and takes the form ``<extra_id_1>``, ``<extra_id_2>``, ... up to ``<extra_id_100>``. By default, 100 sentinel tokens are available in ``T5Tokenizer``.
*E.g.* the sentence "The cute dog walks in the park" with the spans "cute dog" and "the" masked should be processed as follows:
::

    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5ForConditionalGeneration.from_pretrained('t5-small')
    input_ids = tokenizer.encode('The <extra_id_1> walks in <extra_id_2> park', return_tensors='pt')
    lm_labels = tokenizer.encode('<extra_id_1> cute dog <extra_id_2> the <extra_id_3> </s>', return_tensors='pt')
    # the forward function automatically creates the correct decoder_input_ids
    model(input_ids=input_ids, lm_labels=lm_labels)
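Each sentinel string corresponds to a single entry in the tokenizer's vocabulary; a quick, purely illustrative way to inspect this is::

    sentinel_id = tokenizer.convert_tokens_to_ids('<extra_id_1>')
    print(sentinel_id, tokenizer.convert_ids_to_tokens(sentinel_id))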
- Supervised training
In this setup the input and output sequences form a standard sequence-to-sequence mapping.
For translation, *e.g.*, the input sequence "The house is wonderful." and the output sequence "Das Haus ist wunderbar." should
be processed as follows:
::

    input_ids = tokenizer.encode('translate English to German: The house is wonderful. </s>', return_tensors='pt')
    lm_labels = tokenizer.encode('Das Haus ist wunderbar. </s>', return_tensors='pt')
    # the forward function automatically creates the correct decoder_input_ids
    model(input_ids=input_ids, lm_labels=lm_labels)
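Both snippets above only run a forward pass. A full fine-tuning step would also make use of the returned loss; the sketch below assumes that the first element of the output tuple is the cross-entropy loss whenever ``lm_labels`` are passed, and the optimizer choice is purely illustrative::

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    outputs = model(input_ids=input_ids, lm_labels=lm_labels)
    loss = outputs[0]  # assumed: the loss comes first when lm_labels are given
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()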
Tips
~~~~~~~~~~~~~~~~~~~~

- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised
  and supervised tasks and for which each task is converted into a text-to-text format.
  T5 works well on a variety of tasks out-of-the-box by prepending a different prefix to the input corresponding to each task, *e.g.*: *translate English to German: ...* for translation, *summarize: ...* for summarization.
  For more information about which prefix to use, it is easiest to look into Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`_ .
- For sequence to sequence generation, it is recommended to use ``T5ForConditionalGeneration.generate()``. The method takes care of feeding the encoded input via cross-attention layers to the decoder and auto-regressively generates the decoder output.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
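Tying the training examples and the ``generate()`` tip together, inference for the translation case above could look roughly as follows (a short sketch reusing the ``tokenizer`` and ``model`` objects loaded in the training section; default ``generate()`` arguments are used for brevity)::

    input_ids = tokenizer.encode('translate English to German: The house is wonderful. </s>', return_tensors='pt')
    # generate() encodes the input once and then decodes auto-regressively, starting from the PAD token
    generated_ids = model.generate(input_ids)
    print(tokenizer.decode(generated_ids[0]))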
@@ -73,7 +73,7 @@ BART_INPUTS_DOCSTRING = r"""
Mask to avoid performing attention on padding token indices in input_ids.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
- encoder_outputs (tuple(:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
@@ -699,32 +699,24 @@ T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
- To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:
- (a) For sequence pairs:
- ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
- (b) For single sequences:
- ``tokens: [CLS] the dog is hairy . [SEP]``
T5 is a model with relative position embeddings so you should be able to pad the inputs on
the right or the left.
Indices can be obtained using :class:`transformers.T5Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+ To know more on how to prepare :obj:`input_ids` for pre-training take a look at
+ `T5 Training <./t5.html#training>`_ .
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
- encoder_outputs (tuple(:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
+ encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
+ To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
+ `T5 Training <./t5.html#training>`_ .
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
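As the docstring above notes, precomputed encoder states can be passed back in through ``encoder_outputs`` so that the encoder only runs once, e.g. when scoring several candidate target sequences against the same source. A rough sketch, under the assumption that the encoder sub-module is reachable as ``model.encoder`` and returns a tuple whose first element is the last hidden state; as above, the loss is assumed to be the first element of the output tuple, and the ``tokenizer``/``model`` objects from the training section are reused::

    input_ids = tokenizer.encode('translate English to German: The house is wonderful. </s>', return_tensors='pt')
    encoder_outputs = model.encoder(input_ids=input_ids)  # tuple: (last_hidden_state, ...)

    for target in ['Das Haus ist wunderbar. </s>', 'Das Haus ist schrecklich. </s>']:
        lm_labels = tokenizer.encode(target, return_tensors='pt')
        # reuse the cached encoder states; only the decoder runs here
        loss = model(encoder_outputs=encoder_outputs, lm_labels=lm_labels)[0]
        print(target, loss.item())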
@@ -637,27 +637,18 @@ T5_INPUTS_DOCSTRING = r"""
input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
- To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:
- (a) For sequence pairs:
- ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
- (b) For single sequences:
- ``tokens: [CLS] the dog is hairy . [SEP]``
T5 is a model with relative position embeddings so you should be able to pad the inputs on
the right or the left.
Indices can be obtained using :class:`transformers.T5Tokenizer`.
+ To know more on how to prepare :obj:`input_ids` for pre-training take a look at
+ `T5 Training <./t5.html#training>`_ .
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
- encoder_outputs (tuple(:obj:`tuple(tf.FloatTensor)`, `optional`, defaults to :obj:`None`):
+ encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
@@ -671,6 +662,8 @@ T5_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
+ To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
+ `T5 Training <./t5.html#training>`_ .
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
@@ -836,7 +829,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
outputs = model(input_ids, input_ids=input_ids, lm_labels=input_ids)
- prediction_scores = outputs[:1]  # TODO: TFT5 still needs to implement
+ prediction_scores = outputs[:1]
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
@@ -879,7 +872,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
head_mask=head_mask,
)
- # TODO (thom / patrick): add lm_labels for loss function
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
embed_tokens = self.get_output_embeddings()
lm_logits = embed_tokens(sequence_output, mode="linear")