Unverified Commit 3c682ea1 authored by Patrick von Platen, committed by GitHub

[Examples] Allow EncoderDecoderModels to be trained with Seq2Seq (#7809)

* Make Seq2Seq Trainer more similar to Trainer

* fix typo

* fix seq2seq trainer

* remove from tests

* remove lock

* remove train files

* delete test files

* correct typo

* check at init

* make sure trainer is not slowed down on TPU

* correct isort

* remove use cache

* fix use cache

* add last use cache = false
parent 59b5953d
@@ -16,7 +16,6 @@ from transformers import (
)
from transformers.trainer_utils import EvaluationStrategy
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataCollator,
Seq2SeqDataset,
assert_all_frozen,
@@ -138,6 +137,10 @@ class DataTrainingArguments:
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
ignore_pad_token_for_loss: bool = field(
default=True,
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
)
def main():
@@ -223,7 +226,7 @@ def main():
freeze_params(model.get_encoder())
assert_all_frozen(model.get_encoder())
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
dataset_class = Seq2SeqDataset
# Get datasets
train_dataset = (
......
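For context, the new `ignore_pad_token_for_loss` flag boils down to masking padded label positions out of the token-level cross-entropy loss, which is why it requires `config.pad_token_id` to be defined. A minimal sketch of that behaviour, not taken from this diff (the shapes and the pad id are assumptions for illustration):

    import torch

    pad_token_id = 0                          # assumed pad id; the script reads config.pad_token_id
    logits = torch.randn(2, 5, 32)            # (batch, target_len, vocab_size)
    labels = torch.randint(1, 32, (2, 5))     # target token ids
    labels[:, -2:] = pad_token_id             # pretend the last two positions are padding

    # ignore_index drops the pad positions from the loss, which is what the flag enables
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
    loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))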
import logging
import copy
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers import PreTrainedModel, Trainer, logging
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import (
@@ -27,7 +27,7 @@ except ImportError:
from utils import label_smoothed_nll_loss
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
@@ -41,13 +41,25 @@ arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
class Seq2SeqTrainer(Trainer):
def __init__(self, config, data_args, *args, **kwargs):
def __init__(self, config=None, data_args=None, *args, **kwargs):
super().__init__(*args, **kwargs)
if config is None:
assert isinstance(
self.model, PreTrainedModel
), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}"
self.config = self._actual_model(self.model).config
else:
self.config = config
self.data_args = data_args
self.max_gen_length = data_args.val_max_target_length
self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size
if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss):
assert (
self.config.pad_token_id is not None
), "Make sure that `config.pad_token_id` is correctly defined when ignoring `pad_token` for loss calculation or doing label smoothing."
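Because `config` and `data_args` are now optional, the trainer can be handed a plain `PreTrainedModel` such as an `EncoderDecoderModel` and will fall back to `model.config` for `pad_token_id`, `vocab_size`, `max_length` and `num_beams`. A rough construction sketch under that assumption; the local import paths are guesses and `train_dataset`/`val_dataset` stand for tokenized datasets prepared as in the new test further down:

    from transformers import EncoderDecoderModel

    from finetune_trainer import Seq2SeqTrainingArguments  # assumed local imports, as in examples/seq2seq
    from seq2seq_trainer import Seq2SeqTrainer

    model = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
    training_args = Seq2SeqTrainingArguments(output_dir="bert2bert-output", predict_with_generate=True)

    # No `config` or `data_args` passed: the trainer asserts the model is a PreTrainedModel
    # and reads padding/generation settings from model.config instead.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # placeholder: built as in test_finetune_bert2bert below
        eval_dataset=val_dataset,
    )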
def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Setup the optimizer and the learning rate scheduler.
@@ -114,23 +126,31 @@ class Seq2SeqTrainer(Trainer):
else DistributedSampler(self.train_dataset)
)
def compute_loss(self, model, inputs):
def _compute_loss(self, model, inputs):
inputs = copy.deepcopy(inputs)
if self.args.label_smoothing == 0:
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
# force training to ignore pad token
labels = inputs.pop("labels")
outputs = model(**inputs, use_cache=False)
logits = outputs[0]
return self._compute_loss(logits, labels)
logits = model(**inputs, use_cache=False)[0]
def _compute_loss(self, logits, labels):
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
assert logits.shape[-1] == self.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
# compute usual loss via models
loss, logits = model(**inputs, use_cache=False)[:2]
else:
# compute label smoothed loss
labels = inputs.pop("labels")
logits = model(**inputs, use_cache=False)[0]
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
)
return loss, logits
def compute_loss(self, model, inputs):
loss, _ = self._compute_loss(model, inputs)
return loss
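When `label_smoothing` is non-zero, the loss above goes through the `label_smoothed_nll_loss` helper imported from `utils`, which mixes the per-token NLL with a uniform smoothing term over the vocabulary while zeroing out pad positions. The helper itself is not part of this diff; a hedged sketch of that computation, matching the call signature above but possibly differing from the actual `utils` implementation:

    def sketch_label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index):
        # lprobs: (batch, seq_len, vocab) log-probabilities; target: (batch, seq_len) token ids
        target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)       # per-token negative log-likelihood
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)       # uniform term over the whole vocabulary
        pad_mask = target.eq(ignore_index)                    # here ignore_index is the pad token id
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0).sum()
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0).sum()
        eps_i = epsilon / lprobs.size(-1)
        return (1.0 - epsilon) * nll_loss + eps_i * smooth_loss, nll_loss

    # e.g. loss, _ = sketch_label_smoothed_nll_loss(lprobs, labels, 0.1, pad_token_id)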
def prediction_step(
@@ -158,31 +178,37 @@ class Seq2SeqTrainer(Trainer):
"""
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
if self.args.predict_with_generate and not self.args.prediction_loss_only:
gen_kwargs = {
"max_length": self.data_args.val_max_target_length
if self.data_args is not None
else self.config.max_length,
"num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams,
}
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=self.data_args.eval_beams,
max_length=self.max_gen_length,
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, self.max_gen_length)
if self.config.pad_token_id is not None:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
# compute loss on predict data
with torch.no_grad():
loss, logits = self._compute_loss(model, inputs)
labels_out = inputs.get("labels")
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)
logits = generated_tokens if self.args.predict_with_generate else outputs[1]
logits = generated_tokens if self.args.predict_with_generate else logits
labels = inputs["labels"]
if self.config.pad_token_id is not None:
labels = self._pad_tensors_to_max_len(labels, self.config.max_length)
labels_out = labels_out.detach()
labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length)
return (loss, logits.detach(), labels)
return (loss, logits, labels)
def _pad_tensors_to_max_len(self, tensor, max_length):
padded_tensor = self.config.pad_token_id * torch.ones(
......
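The `_pad_tensors_to_max_len` helper is cut off above; its job is to right-pad generated ids and labels with `pad_token_id` up to a fixed length so that differently sized batches can be gathered during evaluation. A standalone sketch of what such a helper plausibly does (the exact body is not shown in this diff):

    import torch

    def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
        # start from a (batch, max_length) tensor filled with the pad id ...
        padded = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
        # ... and copy the real ids into the leading positions
        padded[:, : tensor.shape[-1]] = tensor
        return padded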
@@ -5,12 +5,14 @@ from unittest.mock import patch
import pytest
from transformers import is_torch_available
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
from transformers.file_utils import is_datasets_available
from transformers.testing_utils import TestCasePlus, slow
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import main
from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .test_seq2seq_examples import MBART_TINY
from .utils import execute_async_std
@@ -50,6 +52,117 @@ class TestFinetuneTrainer(TestCasePlus):
assert "test_generations.txt" in contents
assert "test_results.json" in contents
@slow
def test_finetune_bert2bert(self):
if not is_datasets_available():
return
import datasets
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4
def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask
batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask
assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])
return batch
def _compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid
return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)
output_dir = self.get_auto_remove_tmp_dir()
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluate_during_training=True,
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)
# start training
trainer.train()
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
# XXX: remove hardcoded path
......
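As a follow-up to `test_finetune_bert2bert` above, the fine-tuned model can be exercised the same way `_compute_metrics` consumes predictions: generate token ids and decode them back to text. A small sketch under the same assumptions as the test (`bert2bert` and `tokenizer` as above, `article` being any input string):

    # Summarize one article with the trained EncoderDecoderModel and decode the result.
    inputs = tokenizer(article, truncation=True, max_length=512, return_tensors="pt")
    generated = bert2bert.generate(inputs.input_ids, attention_mask=inputs.attention_mask)
    summary = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]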