Unverified Commit 724e51c6 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Compute loss independent from decoder for TF EncDec models (as #14139) (#15175)



* Compute loss independent from decoder (as 14139)

* fix expected seq_len + style

* Apply the same change to TFVisionEncoderDecoderModel

* fix style

* Add case with labels in equivalence test

* uncomment

* Add case with labels in equivalence test

* add decoder_token_labels

* use hf_compute_loss

* Apply suggestions from code review
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add copied from
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 3d5dea9b
......@@ -16,6 +16,7 @@
import tempfile
import warnings
from typing import Optional
import tensorflow as tf
......@@ -29,7 +30,13 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
get_initializer,
input_processing,
shape_list,
)
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
......@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
DEPRECATION_WARNING = (
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
"need to pass them yourself anymore."
)
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
......@@ -145,8 +159,36 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
r"""
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
of the base model classes of the library as encoder and another one as decoder when created with the
......@@ -566,6 +608,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
decoder_processing_inputs = {
"func": self.decoder.call,
"config": self.decoder.config,
......@@ -574,7 +621,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": attention_mask,
"inputs_embeds": decoder_inputs_embeds,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"use_cache": use_cache,
......@@ -592,12 +638,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
past_key_values = None
logits = decoder_outputs[0]
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
warnings.warn(DEPRECATION_WARNING, FutureWarning)
loss = self.hf_compute_loss(labels, logits)
past_key_values = None
if decoder_inputs["use_cache"]:
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
past_key_values = decoder_outputs[1]
# The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
......@@ -611,7 +662,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
return output
return TFSeq2SeqLMOutput(
loss=decoder_outputs.loss,
loss=loss,
logits=decoder_outputs.logits,
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
......@@ -693,6 +744,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
......
......@@ -16,6 +16,7 @@
import tempfile
import warnings
from typing import Optional
import tensorflow as tf
......@@ -29,7 +30,13 @@ from ...file_utils import (
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
get_initializer,
input_processing,
shape_list,
)
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
......@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
DEPRECATION_WARNING = (
"Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
"encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
"a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
"need to pass them yourself anymore."
)
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
......@@ -134,8 +148,37 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
if pad_token_id is None:
raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
if decoder_start_token_id is None:
raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class TFVisionEncoderDecoderModel(TFPreTrainedModel):
class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
r"""
[`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
with one of the base vision model classes of the library as encoder and another one of the base model classes as
......@@ -594,6 +637,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
......@@ -605,7 +653,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"inputs_embeds": decoder_inputs_embeds,
"labels": labels,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"use_cache": use_cache,
......@@ -622,12 +669,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
past_key_values = None
logits = decoder_outputs[0]
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
warnings.warn(DEPRECATION_WARNING, FutureWarning)
loss = self.hf_compute_loss(labels, logits)
past_key_values = None
if decoder_inputs["use_cache"]:
past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
past_key_values = decoder_outputs[1]
# The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
......@@ -641,7 +693,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
return output
return TFSeq2SeqLMOutput(
loss=decoder_outputs.loss,
loss=loss,
logits=decoder_outputs.logits,
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
......@@ -715,6 +767,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import os
import tempfile
import unittest
......@@ -237,7 +238,7 @@ class TFEncoderDecoderMixin:
)
# Make sure `loss` exist
assert "loss" in outputs_encoder_decoder
self.assertIn("loss", outputs_encoder_decoder)
batch_size, seq_len = decoder_input_ids.shape
expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
......@@ -319,12 +320,18 @@ class TFEncoderDecoderMixin:
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
if "labels" in pt_inputs:
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
tf_outputs = tf_model(**inputs_dict).to_tuple()
tf_outputs = tf_model(**inputs_dict)
if "loss" in tf_outputs:
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
tf_outputs = tf_outputs.to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
......@@ -339,8 +346,12 @@ class TFEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
if "loss" in tf_outputs_loaded:
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
......@@ -435,6 +446,8 @@ class TFEncoderDecoderMixin:
def test_pt_tf_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
# Keep only common arguments
arg_names = [
"config",
......@@ -454,6 +467,9 @@ class TFEncoderDecoderMixin:
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
inputs_dict_with_labels = copy.copy(inputs_dict)
inputs_dict_with_labels["labels"] = labels
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
......@@ -471,6 +487,10 @@ class TFEncoderDecoderMixin:
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# check equivalence with labels
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
# which randomly initialize `enc_to_dec_proj`.
# # check `enc_to_dec_proj` work as expected
......
......@@ -15,6 +15,7 @@
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
import copy
import os
import tempfile
import unittest
......@@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin:
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
if "labels" in pt_inputs:
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
tf_outputs = tf_model(**inputs_dict).to_tuple()
tf_outputs = tf_model(**inputs_dict)
if "loss" in tf_outputs:
tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
tf_outputs = tf_outputs.to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
......@@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin:
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
tf_outputs_loaded = tf_model_loaded(**inputs_dict)
if "loss" in tf_outputs_loaded:
tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
tf_outputs_loaded = tf_outputs_loaded.to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
......@@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin:
def test_pt_tf_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
labels = config_inputs_dict.pop("decoder_token_labels")
# Keep only common arguments
arg_names = [
"config",
......@@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin:
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
inputs_dict_with_labels = copy.copy(inputs_dict)
inputs_dict_with_labels["labels"] = labels
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
......@@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin:
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# check equivalence with labels
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
# which randomly initialize `enc_to_dec_proj`.
# # check `enc_to_dec_proj` work as expected
......@@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.Te
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"decoder_token_labels": decoder_token_labels,
"encoder_hidden_states": encoder_hidden_states, # This is not used in the tests.
"labels": decoder_token_labels,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment