"tests/modeling_tf_roberta_test.py" did not exist on "35401fe50fa3e460b2a4422630b017f106c79e03"
Unverified commit ac12a5ae authored by NielsRogge, committed by GitHub

Fix EncoderDecoderModel classes to be more like BART and T5 (#14139)

* First draft

* Make tuple output more readable

* Replace assertions by value errors

* Make it possible to predict_with_generate for vision and speech models

* Adapt Seq2SeqTrainer to work with VisionEncoderDecoder/SpeechEncoderDecoder

* Add deprecation warning

* Add copied from statements to vision and speech encoder decoders

* Fix failing test

* Apply @patrickvonplaten's suggestion

* Use reshape instead of view for consistency
parent 1251072f
@@ -14,9 +14,12 @@
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """
import warnings
from typing import Optional
import torch
from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import Seq2SeqLMOutput
@@ -29,6 +32,13 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
DEPRECATION_WARNING = (
"Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
"a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no "
"need to pass them yourself anymore."
)
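To see what the warning means in practice, here is a minimal sketch of the new workflow, assuming a `bert-base-uncased` checkpoint for both encoder and decoder (checkpoint name and token ids are illustrative, not taken from this diff):

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

# shift_tokens_right needs these two ids on the top-level config (taken from the BERT tokenizer here).
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

enc = tokenizer("A short input sentence.", return_tensors="pt")
labels = tokenizer("A short target sentence.", return_tensors="pt").input_ids

# No decoder_input_ids anymore: they are derived from `labels` inside forward,
# and the loss is computed by the encoder-decoder framework itself.
outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels)
print(outputs.loss)
```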
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
@@ -136,6 +146,24 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
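A small sketch of what `shift_tokens_right` produces, assuming illustrative values `pad_token_id=0` and `decoder_start_token_id=101`:

```python
import torch

# Reuses the shift_tokens_right helper shown above; the token ids here are made up.
labels = torch.tensor([[42, 7, 15, -100, -100]])
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
# Each row now starts with the decoder start token, the last position is dropped,
# and the -100 padding markers are replaced by pad_token_id:
# tensor([[101,  42,   7,  15,    0]])
```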
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
@@ -434,6 +462,11 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_hidden_states = encoder_outputs[0]
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
@@ -441,7 +474,6 @@
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
- labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
@@ -450,11 +482,22 @@
**kwargs_decoder,
)
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
warnings.warn(DEPRECATION_WARNING, FutureWarning)
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
- loss=decoder_outputs.loss,
+ loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
@@ -465,6 +508,9 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
......
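The hunks above (from `modeling_encoder_decoder.py`) drop `labels=labels` from the decoder call and compute the loss in the framework instead. One detail worth keeping in mind: `CrossEntropyLoss` uses `ignore_index=-100` by default, so padded label positions do not contribute to the loss. A minimal sketch with made-up shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)       # (batch, sequence, vocab) from the decoder
labels = torch.tensor([[3, 5, -100, -100]])  # -100 marks padded label positions

loss_fct = CrossEntropyLoss()                # default ignore_index=-100
loss = loss_fct(logits.reshape(-1, vocab_size), labels.view(-1))
# Only the first two positions contribute to the loss; the -100 entries are ignored,
# matching the new framework-level loss computation above.
```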
@@ -17,7 +17,9 @@
from typing import Optional
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -149,6 +151,25 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
@add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
class SpeechEncoderDecoderModel(PreTrainedModel):
r"""
@@ -467,6 +488,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
else:
encoder_attention_mask = None
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
@@ -482,20 +508,34 @@
**kwargs_decoder,
)
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_hidden_states,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
......
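The same pattern is applied to `SpeechEncoderDecoderModel` in the hunks above. A hedged sketch of training such a model with `labels` only; the checkpoint names, token ids and the `input_values` argument (for a Wav2Vec2-style encoder) are assumptions, not taken from this diff:

```python
import torch
from transformers import SpeechEncoderDecoderModel

model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "bert-base-uncased"
)
model.config.decoder_start_token_id = 101  # decoder tokenizer's [CLS] id (illustrative)
model.config.pad_token_id = 0              # decoder tokenizer's [PAD] id (illustrative)

input_values = torch.randn(1, 16000)        # one second of dummy 16 kHz audio
labels = torch.tensor([[42, 7, 15, -100]])  # tokenized transcription, -100 = padding

# decoder_input_ids are built from `labels` and the loss is computed in the framework,
# which is what lets Seq2SeqTrainer's predict_with_generate work for speech models.
outputs = model(input_values=input_values, labels=labels)
print(outputs.loss)
```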
@@ -17,7 +17,9 @@
from typing import Optional
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -29,6 +31,25 @@ from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
@@ -448,6 +469,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
# else:
encoder_attention_mask = None
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
@@ -455,7 +481,6 @@
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
- labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
@@ -464,21 +489,34 @@
**kwargs_decoder,
)
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
- loss=decoder_outputs.loss,
+ loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_hidden_states,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
......
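The vision classes above get the same treatment. A hedged sketch for an image-to-text model where only `labels` are passed and the decoder inputs are derived from them; checkpoint names and token ids are illustrative:

```python
import torch
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
model.config.decoder_start_token_id = 101  # decoder tokenizer's [CLS] id (illustrative)
model.config.pad_token_id = 0              # decoder tokenizer's [PAD] id (illustrative)

pixel_values = torch.randn(1, 3, 224, 224)                # one dummy ViT-sized image
labels = torch.tensor([[2023, 2003, 1037, 14764, -100]])  # tokenized caption, -100 = padding

outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss)  # loss computed in the framework, as in the text and speech classes
```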
@@ -164,9 +164,15 @@ class Seq2SeqTrainer(Trainer):
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
}
if self.tokenizer is not None:
generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names}
# very ugly hack to make it work
generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
else:
generation_inputs = inputs["input_ids"]
generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generation_inputs,
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
@@ -197,15 +203,16 @@
return (loss, generated_tokens, labels)
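For reference, the generation-input selection added to `prediction_step` above can be read as the standalone helper below. The function name `build_generation_inputs` is hypothetical, and the fallback branch returns a dict here, whereas the diff keeps the raw tensor:

```python
def build_generation_inputs(inputs, tokenizer):
    # `model_input_names` lists the primary inputs of a tokenizer or feature extractor,
    # e.g. ["input_ids", ...] for BERT or ["input_values", ...] for Wav2Vec2.
    if tokenizer is not None:
        generation_inputs = {k: v for k, v in inputs.items() if k in tokenizer.model_input_names}
        # Rename the primary input to "input_ids" so `model.generate(**generation_inputs)`
        # works for text, speech and vision models alike (the "ugly hack" in the diff).
        generation_inputs["input_ids"] = generation_inputs.pop(tokenizer.model_input_names[0])
        return generation_inputs
    # Without a tokenizer, fall back to the plain input_ids (wrapped in a dict here,
    # whereas the diff keeps the bare tensor).
    return {"input_ids": inputs["input_ids"]}
```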
def _pad_tensors_to_max_len(self, tensor, max_length):
- if self.tokenizer is None:
-     raise ValueError(
-         f"Tensor need to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
-         "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
-     )
+ if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
+ else:
+     if self.model.config.pad_token_id is not None:
+         pad_token_id = self.model.config.pad_token_id
+     else:
+         raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
......
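The reworked `_pad_tensors_to_max_len` resolves the pad id from the tokenizer first and falls back to the model config, which is what makes padding work when a feature extractor (with no pad token) is passed as `tokenizer`. A sketch of the resolution order and the padding step, with hypothetical helper names:

```python
import torch

def resolve_pad_token_id(tokenizer, model_config):
    if tokenizer is not None and hasattr(tokenizer, "pad_token_id"):
        # Prefer the tokenizer's PAD token; if it is undefined, fall back to EOS.
        return tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    if model_config.pad_token_id is not None:
        # Feature extractors have no pad token, so the model config is the fallback.
        return model_config.pad_token_id
    raise ValueError("pad_token_id must be set in the model configuration to pad tensors")

def pad_to_max_len(tensor, max_length, pad_token_id):
    # Fill a (batch, max_length) tensor with the pad id and copy the generated tokens in.
    padded = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
    padded[:, : tensor.shape[-1]] = tensor
    return padded
```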