Unverified Commit ac12a5ae authored by NielsRogge, committed by GitHub

Fix EncoderDecoderModel classes to be more like BART and T5 (#14139)

* First draft

* Make tuple output more readable

* Replace assertions by value errors

* Make it possible to predict_with_generate for vision and speech models

* Adapt Seq2SeqTrainer to work with VisionEncoderDecoder/SpeechEncoderDecoder

* Add deprecation warning

* Add copied from statements to vision and speech encoder decoders

* Fix failing test

* Apply @patrickvonplaten's suggestion

* Use reshape instead of view for consistency
parent 1251072f
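
In practical terms, this commit lets all three encoder-decoder wrappers be trained by passing only `labels`, the way BART and T5 already work. A minimal usage sketch of what that looks like (not part of the diff; the `bert-base-uncased` checkpoints and token-id choices are illustrative assumptions):

```python
import torch
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

# shift_tokens_right needs these two attributes on the top-level config
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer("The tower is 324 metres tall.", return_tensors="pt")
labels = tokenizer("The tower is really tall.", return_tensors="pt").input_ids

# no decoder_input_ids: they are derived from labels, and the loss is computed
# inside EncoderDecoderModel rather than by the decoder
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels)
outputs.loss.backward()
```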
@@ -14,9 +14,12 @@
 # limitations under the License.
 """ Classes to support Encoder-Decoder architectures """
 
+import warnings
 from typing import Optional
 
+import torch
+from torch.nn import CrossEntropyLoss
+
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_outputs import Seq2SeqLMOutput
@@ -29,6 +32,13 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "EncoderDecoderConfig"
 
+DEPRECATION_WARNING = (
+    "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
+    "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
+    "a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the labels, no "
+    "need to pass them yourself anymore."
+)
+
 ENCODER_DECODER_START_DOCSTRING = r"""
     This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
     encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
@@ -136,6 +146,24 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 """
 
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
 @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
 class EncoderDecoderModel(PreTrainedModel):
     r"""
@@ -434,6 +462,11 @@ class EncoderDecoderModel(PreTrainedModel):
 
         encoder_hidden_states = encoder_outputs[0]
 
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -441,7 +474,6 @@ class EncoderDecoderModel(PreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=attention_mask,
             inputs_embeds=decoder_inputs_embeds,
-            labels=labels,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
@@ -450,11 +482,22 @@ class EncoderDecoderModel(PreTrainedModel):
             **kwargs_decoder,
         )
 
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            warnings.warn(DEPRECATION_WARNING, FutureWarning)
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
         if not return_dict:
-            return decoder_outputs + encoder_outputs
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
 
         return Seq2SeqLMOutput(
-            loss=decoder_outputs.loss,
+            loss=loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
@@ -465,6 +508,9 @@ class EncoderDecoderModel(PreTrainedModel):
             encoder_attentions=encoder_outputs.attentions,
         )
 
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     def prepare_inputs_for_generation(
         self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
......
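
To make the new `shift_tokens_right` helper concrete, here is a small hedged example of what it produces for a batch of labels (the token ids 0, 101, 42, 17 and 9 are made up for illustration):

```python
import torch
from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right

labels = torch.tensor([[42, 17, 9, -100, -100]])  # -100 marks positions ignored by the loss

decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
print(decoder_input_ids)  # tensor([[101,  42,  17,   9,   0]])
```

The decoder start token is prepended, the last label is dropped, and any shifted `-100` values are replaced by the pad token so the decoder never sees the ignore index.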
@@ -17,7 +17,9 @@
 
 from typing import Optional
 
+import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -149,6 +151,25 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 """
 
+# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
 @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
 class SpeechEncoderDecoderModel(PreTrainedModel):
     r"""
@@ -467,6 +488,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         else:
             encoder_attention_mask = None
 
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -482,20 +508,34 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
             **kwargs_decoder,
         )
 
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
         if not return_dict:
-            return decoder_outputs + encoder_outputs
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
 
         return Seq2SeqLMOutput(
+            loss=loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_hidden_states,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
         )
 
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     def prepare_inputs_for_generation(
         self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
......
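
The speech (and vision) classes gain the same `prepare_decoder_input_ids_from_labels` hook that BART and T5 expose, which is what utilities such as `DataCollatorForSeq2Seq` rely on to build decoder inputs from padded labels. A hedged sketch, with illustrative checkpoints and token ids:

```python
import torch
from transformers import SpeechEncoderDecoderModel

# illustrative pairing: a Wav2Vec2 encoder with a BERT decoder
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "bert-base-uncased"
)
model.config.decoder_start_token_id = 101  # [CLS] for bert-base-uncased
model.config.pad_token_id = 0              # [PAD] for bert-base-uncased

labels = torch.tensor([[42, 17, 9, -100]])
print(model.prepare_decoder_input_ids_from_labels(labels))  # tensor([[101, 42, 17, 9]])
```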
@@ -17,7 +17,9 @@
 
 from typing import Optional
 
+import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -29,6 +31,25 @@ from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
 
+# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    shifted_input_ids[:, 0] = decoder_start_token_id
+
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+    return shifted_input_ids
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
@@ -448,6 +469,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         # else:
         encoder_attention_mask = None
 
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -455,7 +481,6 @@ class VisionEncoderDecoderModel(PreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
-            labels=labels,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
@@ -464,21 +489,34 @@ class VisionEncoderDecoderModel(PreTrainedModel):
             **kwargs_decoder,
        )
 
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
         if not return_dict:
-            return decoder_outputs + encoder_outputs
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
 
         return Seq2SeqLMOutput(
-            loss=decoder_outputs.loss,
+            loss=loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_hidden_states,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
         )
 
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     def prepare_inputs_for_generation(
         self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
......
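
For the vision case, the same pattern means an image-to-text model can be fine-tuned with `pixel_values` and `labels` alone. A hedged sketch (the ViT/GPT-2 checkpoints and the random pixel tensor are illustrative stand-ins):

```python
import torch
from transformers import GPT2Tokenizer, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# GPT-2 has no pad token, so reuse EOS for padding purposes
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.eos_token_id

pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a real feature extractor output
labels = tokenizer("a cat sitting on a mat", return_tensors="pt").input_ids

outputs = model(pixel_values=pixel_values, labels=labels)  # loss computed in the framework
print(outputs.loss)
```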
@@ -164,9 +164,15 @@ class Seq2SeqTrainer(Trainer):
             "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
         }
 
+        if self.tokenizer is not None:
+            generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names}
+            # very ugly hack to make it work
+            generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
+        else:
+            generation_inputs = inputs["input_ids"]
+
         generated_tokens = self.model.generate(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
+            **generation_inputs,
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
@@ -197,15 +203,16 @@ class Seq2SeqTrainer(Trainer):
         return (loss, generated_tokens, labels)
 
     def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.tokenizer is None:
-            raise ValueError(
-                f"Tensor need to be padded to `max_length={max_length}` but no tokenizer was passed when creating "
-                "this `Trainer`. Make sure to create your `Trainer` with the appropriate tokenizer."
-            )
-        # If PAD token is not defined at least EOS token has to be defined
-        pad_token_id = (
-            self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-        )
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            # If PAD token is not defined at least EOS token has to be defined
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+
         padded_tensor = pad_token_id * torch.ones(
             (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
......
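
The Seq2SeqTrainer change above is what makes `predict_with_generate` usable for vision and speech models: the "tokenizer" handed to the Trainer may actually be a feature extractor, whose `model_input_names` do not start with `input_ids`. A hedged sketch of the filtering-and-renaming step (the ViT-style input names and the placeholder batch are assumptions for illustration):

```python
# e.g. ViTFeatureExtractor.model_input_names == ["pixel_values"]
model_input_names = ["pixel_values"]
inputs = {"pixel_values": "<batch of images>", "labels": "<batch of token ids>"}

# keep only the keys the encoder expects, then rename the first one to input_ids,
# since generate() at this point takes its main input under that keyword
generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
generation_inputs["input_ids"] = generation_inputs.pop(model_input_names[0])

print(generation_inputs)  # {'input_ids': '<batch of images>'}
# model.generate(**generation_inputs, **gen_kwargs) then receives the encoder inputs
```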