"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2f3a4210185f5311f6cfab3c91b30616c9a30fc8"
Unverified Commit f85acb4d authored by Ekagra Ranjan, committed by GitHub

Fix decoder_input_ids for bare T5Model and improve doc (#18791)



* use tokenizer to output tensor

* add preprocessing for decoder_input_ids for bare T5Model

* add preprocessing to tf and flax

* linting

* linting

* Update src/transformers/models/t5/modeling_flax_t5.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/t5/modeling_tf_t5.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/t5/modeling_t5.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 3b19c031
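For context on the fix: the bare T5Model takes decoder_input_ids as-is, so callers must prepend the decoder start token (which for T5 is the pad token) themselves, whereas T5ForConditionalGeneration performs this shift internally when labels are passed. A minimal PyTorch sketch of the corrected usage, assuming the t5-small checkpoint used in the updated docstrings:

# Minimal sketch of the corrected bare-T5Model usage; assumes the
# "t5-small" checkpoint, as in the docstrings changed below.
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids

# Bare T5Model does not shift decoder inputs itself: prepend the decoder
# start token (the pad token for T5) before the forward pass.
decoder_input_ids = model._shift_right(decoder_input_ids)

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print(outputs.last_hidden_state.shape)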
@@ -187,12 +187,15 @@ ignored. The code example below illustrates all of this.
 >>> # encode the targets
 >>> target_encoding = tokenizer(
-...     [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
+...     [output_sequence_1, output_sequence_2],
+...     padding="longest",
+...     max_length=max_target_length,
+...     truncation=True,
+...     return_tensors="pt",
 ... )
 >>> labels = target_encoding.input_ids
 >>> # replace padding token ids of the labels by -100 so they are ignored by the loss
->>> labels = torch.tensor(labels)
 >>> labels[labels == tokenizer.pad_token_id] = -100
 >>> # forward pass
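The hunk above now builds the labels as a tensor directly (return_tensors="pt"), dropping the manual torch.tensor call. A self-contained, runnable version of that training example, with illustrative sentence pairs standing in for the doc's output_sequence variables and assuming the t5-small checkpoint:

# Runnable sketch of the training example above; the sentence pairs are
# illustrative, not from the original doc.
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_encoding = tokenizer(
    [
        "translate English to German: The house is wonderful.",
        "translate English to German: Welcome to NYC.",
    ],
    padding="longest",
    return_tensors="pt",
)
target_encoding = tokenizer(
    ["Das Haus ist wunderbar.", "Willkommen in NYC."],
    padding="longest",
    max_length=128,
    truncation=True,
    return_tensors="pt",
)
labels = target_encoding.input_ids
# Replace padding token ids in the labels by -100 so they are ignored by the loss.
labels[labels == tokenizer.pad_token_id] = -100

# T5ForConditionalGeneration shifts the labels right internally to build
# decoder_input_ids, so no manual _shift_right is needed here.
loss = model(
    input_ids=input_encoding.input_ids,
    attention_mask=input_encoding.attention_mask,
    labels=labels,
).loss
print(loss.item())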
@@ -1388,6 +1388,10 @@ FLAX_T5_MODEL_DOCSTRING = """
 ...     ).input_ids
 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
+>>> # preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model
+>>> # This is not needed for torch's T5ForConditionalGeneration, which does this internally using the labels arg.
+>>> decoder_input_ids = model._shift_right(decoder_input_ids)
+>>> # forward pass
 >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
 >>> last_hidden_states = outputs.last_hidden_state
@@ -1383,6 +1383,10 @@ class T5Model(T5PreTrainedModel):
 ...     ).input_ids  # Batch size 1
 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
+>>> # preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model
+>>> # This is not needed for torch's T5ForConditionalGeneration, which does this internally using the labels arg.
+>>> decoder_input_ids = model._shift_right(decoder_input_ids)
+>>> # forward pass
 >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
 >>> last_hidden_states = outputs.last_hidden_state
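The _shift_right used in these docstrings is a helper on the T5 pretrained model classes. Roughly, it moves every token one position to the right, writes the decoder start token into position 0, and maps any -100 label padding back to the pad token. An illustrative stand-alone re-implementation (the function below is a sketch, not the library's exact code):

import torch

# Illustrative re-implementation of what model._shift_right does; a sketch,
# not the library's exact code.
def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()  # move every token one slot right
    shifted[:, 0] = decoder_start_token_id      # put the start token in front
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 label padding -> pad token
    return shifted

For T5 both ids come from the model config, and decoder_start_token_id equals pad_token_id (0), which is why the docstring comments describe the start token as the pad token.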
@@ -1180,6 +1180,10 @@ class TFT5Model(TFT5PreTrainedModel):
 ...     ).input_ids  # Batch size 1
 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids  # Batch size 1
+>>> # preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model
+>>> # This is not needed for torch's T5ForConditionalGeneration, which does this internally using the labels arg.
+>>> decoder_input_ids = model._shift_right(decoder_input_ids)
+>>> # forward pass
 >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
 >>> last_hidden_states = outputs.last_hidden_state