Unverified Commit 5f80c15e authored by Sylvain Gugger, committed via GitHub

Fix memory regression in Seq2Seq example (#9713)

* Fix memory regression in Seq2Seq example

* Fix test and properly deal with -100

* Easier condition with device safety

* Patch for MBartTokenizerFast
parent a7dabfb3
@@ -26,6 +26,7 @@ from transformers import (
     AutoTokenizer,
     HfArgumentParser,
     MBartTokenizer,
+    MBartTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
@@ -220,11 +221,14 @@ def main():
         data_args.eval_beams = model.config.num_beams

     # set decoder_start_token_id for MBart
-    if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
+    if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
         assert (
             data_args.tgt_lang is not None and data_args.src_lang is not None
         ), "mBart requires --tgt_lang and --src_lang"
-        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        if isinstance(tokenizer, MBartTokenizer):
+            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
+        else:
+            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.tgt_lang)

     if model_args.freeze_embeds:
         freeze_embeds(model)
@@ -284,7 +288,9 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
+        data_collator=Seq2SeqDataCollator(
+            tokenizer, data_args, model.config.decoder_start_token_id, training_args.tpu_num_cores
+        ),
         compute_metrics=compute_metrics_fn,
         tokenizer=tokenizer,
     )
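The branching above exists because only the slow MBartTokenizer exposes a lang_code_to_id mapping, while the fast tokenizer can resolve the same language code through the generic convert_tokens_to_ids. A minimal sketch of the two lookups, assuming the facebook/mbart-large-cc25 checkpoint and the ro_RO target language code (illustrative choices, not taken from this diff):

import torch
from transformers import MBartTokenizer, MBartTokenizerFast

# Checkpoint and language code are placeholders used purely for illustration.
slow = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
fast = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")

start_id_slow = slow.lang_code_to_id["ro_RO"]        # explicit mapping (slow tokenizer only)
start_id_fast = fast.convert_tokens_to_ids("ro_RO")  # generic token-to-id lookup
# Both should resolve to the same id, which the script stores as decoder_start_token_id.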
@@ -33,8 +33,9 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler

 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
+from transformers.models.bart.modeling_bart import shift_tokens_right


 try:
@@ -274,9 +275,10 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):


 class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+    def __init__(self, tokenizer, data_args, decoder_start_token_id, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        self.decoder_start_token_id = decoder_start_token_id
         assert (
             self.pad_token_id is not None
         ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
@@ -304,9 +306,15 @@ class Seq2SeqDataCollator:
         labels = trim_batch(labels, self.pad_token_id)
         input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)

+        if isinstance(self.tokenizer, T5Tokenizer):
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id, self.decoder_start_token_id)
+
         batch = {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
             "labels": labels,
         }
         return batch
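With the labels popped out of the model inputs (see the Trainer change below), the collator now has to build decoder_input_ids itself, which is what the imported shift_tokens_right does. As a rough stand-in for that helper (an approximation, not the library's exact implementation), the shift looks like this, including the -100 handling mentioned in the commit message:

import torch

def shift_tokens_right_sketch(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and drop the last label position.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 is only an ignore marker for the loss; it must never reach the
    # embedding layer, so such positions are replaced by the pad token.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted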
@@ -1297,14 +1297,18 @@ class Trainer:

         Subclass and override for custom behavior.
         """
+        if self.label_smoother is not None and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]

-        if self.label_smoother is not None and "labels" in inputs:
-            return self.label_smoother(outputs, inputs["labels"])
+        if labels is not None:
+            return self.label_smoother(outputs, labels)
         else:
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
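The gist of the fix, as I read it (not wording from the patch): with label smoothing enabled, leaving "labels" in the inputs made the model compute its own cross-entropy loss over the full vocabulary only for the smoother to discard it, which costs both compute and memory; popping the labels before the forward pass avoids that, and it is also why the example's collator above now supplies decoder_input_ids explicitly instead of relying on the model to derive them from labels. Label smoothing itself is still switched on through the training arguments, roughly like this (values are illustrative):

from transformers import Seq2SeqTrainingArguments

# Illustrative: with label_smoothing_factor > 0, Trainer builds a LabelSmoother
# and takes the new code path that pops "labels" before calling the model.
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    label_smoothing_factor=0.1,
)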
@@ -380,17 +380,26 @@ class LabelSmoother:
     ignore_index: int = -100

     def __call__(self, model_output, labels):
-        model_loss = model_output["loss"] if isinstance(model_output, dict) else model_output[0]
-        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[1]
+        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
         log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
+        if labels.dim() == log_probs.dim() - 1:
+            labels = labels.unsqueeze(-1)

-        # Look at the ignored index and mask the corresponding log_probs.
-        padding_mask = labels.unsqueeze(-1).eq(self.ignore_index)
-        log_probs.masked_fill_(padding_mask, 0.0)
+        padding_mask = labels.eq(self.ignore_index)
+        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
+        # will ignore them in any case.
+        labels.clamp_min_(0)
+        nll_loss = log_probs.gather(dim=-1, index=labels)
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+
+        nll_loss.masked_fill_(padding_mask, 0.0)
+        smoothed_loss.masked_fill_(padding_mask, 0.0)

         # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
-        smoothed_loss = log_probs.mean(dim=-1).sum() / (padding_mask.numel() - padding_mask.long().sum())
-        return (1 - self.epsilon) * model_loss + self.epsilon * smoothed_loss
+        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+        nll_loss = nll_loss.sum() / num_active_elements
+        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
+        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss


 def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
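In effect the smoother now computes the full label-smoothed loss itself from the logits: (1 - epsilon) times the negative log-likelihood at the gold labels plus epsilon times the mean negative log-probability over all classes, both averaged over the non-ignored positions. A small self-contained check (illustrative, not part of the patch) that the gather-based NLL term matches cross_entropy with ignore_index -100:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_labels = 12
logits = torch.randn(4, 5, num_labels)
labels = torch.randint(0, num_labels, (4, 5))
labels[2, 3] = -100  # one ignored position, as in the test below

# Reference: cross_entropy averages over the non-ignored targets only.
reference = F.cross_entropy(logits.view(-1, num_labels), labels.view(-1))

# Same quantity computed the way the new LabelSmoother does it.
log_probs = -F.log_softmax(logits, dim=-1)
padding_mask = labels.eq(-100).unsqueeze(-1)
nll = log_probs.gather(-1, labels.clamp_min(0).unsqueeze(-1))
nll = nll.masked_fill(padding_mask, 0.0).sum() / labels.ne(-100).sum()

assert torch.allclose(reference, nll)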
@@ -71,7 +71,7 @@ class TrainerUtilsTest(unittest.TestCase):
         random_logits = torch.randn(4, 5, num_labels)
         random_labels = torch.randint(0, num_labels, (4, 5))
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         expected_loss = (1 - epsilon) * loss + epsilon * log_probs.mean()
@@ -83,7 +83,7 @@ class TrainerUtilsTest(unittest.TestCase):
         random_labels[2, 3] = -100
         loss = torch.nn.functional.cross_entropy(random_logits.view(-1, num_labels), random_labels.view(-1))
-        model_output = SequenceClassifierOutput(loss=loss, logits=random_logits)
+        model_output = SequenceClassifierOutput(logits=random_logits)
         label_smoothed_loss = LabelSmoother(0.1)(model_output, random_labels)
         log_probs = -torch.nn.functional.log_softmax(random_logits, dim=-1)
         # Mask the log probs with the -100 labels
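Because the smoother no longer reads the model loss, the test builds the output with logits only. A minimal usage sketch along the same lines (illustrative values, not the test's literal code):

import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import LabelSmoother

torch.manual_seed(0)
num_labels, epsilon = 12, 0.1
logits = torch.randn(4, 5, num_labels)
labels = torch.randint(0, num_labels, (4, 5))
labels[2, 3] = -100  # ignored position

# The smoother only needs the logits; positions labelled -100 contribute
# neither to the NLL term nor to the smoothed term.
smoothed_loss = LabelSmoother(epsilon)(SequenceClassifierOutput(logits=logits), labels)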