Unverified commit d1d15d6f, authored by Suraj Patil, committed by GitHub
Browse files

[examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994)

parent 3deffc1d
@@ -14,7 +14,7 @@ import torch
 from torch.utils.data import DataLoader
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, get_linear_schedule_with_warmup
+from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup

 try:
@@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
-        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
-        lm_labels = target_ids[:, 1:].clone()  # why clone?
+        if isinstance(self.model, T5ForConditionalGeneration):
+            decoder_input_ids = self.model._shift_right(target_ids)
+            lm_labels = target_ids
+        else:
+            decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
+            lm_labels = target_ids[:, 1:].clone()  # why clone?
         outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         if self.hparams.label_smoothing == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.