"docs/vscode:/vscode.git/clone" did not exist on "5ac8b62265efac24f0dbfab271d2bce534179993"
Unverified commit d1d15d6f authored by Suraj Patil, committed by GitHub

[examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994)

parent 3deffc1d
@@ -14,7 +14,7 @@ import torch
 from torch.utils.data import DataLoader
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, get_linear_schedule_with_warmup
+from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
 try:
@@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
-        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
-        lm_labels = target_ids[:, 1:].clone()  # why clone?
+        if isinstance(self.model, T5ForConditionalGeneration):
+            decoder_input_ids = self.model._shift_right(target_ids)
+            lm_labels = target_ids
+        else:
+            decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
+            lm_labels = target_ids[:, 1:].clone()  # why clone?
         outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         if self.hparams.label_smoothing == 0:
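
For context, below is a minimal sketch of the teacher-forcing shift the new T5 branch relies on. The shift_right helper and the example ids are purely illustrative, not the library code; it assumes (as in T5's default configuration) that the pad token doubles as the decoder start token, and it omits details such as replacing -100 label positions with the pad id, which the real T5ForConditionalGeneration._shift_right also handles.

    import torch

    def shift_right(target_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
        # Shift the labels one position to the right: prepend the decoder start
        # token and drop the last token, so the decoder sees token t-1 when it
        # is asked to predict token t (standard teacher forcing).
        shifted = target_ids.new_zeros(target_ids.shape)
        shifted[:, 1:] = target_ids[:, :-1].clone()
        shifted[:, 0] = decoder_start_token_id
        return shifted

    # Illustrative ids only: 0 = pad (T5's decoder start token), 1 = eos.
    target_ids = torch.tensor([[37, 423, 55, 1]])
    decoder_input_ids = shift_right(target_ids, decoder_start_token_id=0)
    print(decoder_input_ids)  # tensor([[  0,  37, 423,  55]])
    print(target_ids)         # labels stay unshifted: tensor([[ 37, 423,  55,   1]])

The else branch achieves the same one-token offset by slicing, using target_ids[:, :-1] as decoder inputs and target_ids[:, 1:] as labels, which presumes the encoded targets already begin with a suitable start token; T5-tokenized targets do not, which is why the explicit shift is needed here.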