"...composable_kernel_rocm.git" did not exist on "b76d3d9f96278b0d6eb17b72c37cb2353474638e"
Unverified Commit 724e51c6 authored by Yih-Dar, committed by GitHub

Compute loss independent from decoder for TF EncDec models (as #14139) (#15175)



* Compute loss independent from decoder (as #14139)

* fix expected seq_len + style

* Apply the same change to TFVisionEncoderDecoderModel

* fix style

* Add case with labels in equivalence test

* uncomment

* Add case with labels in equivalence test

* add decoder_token_labels

* use hf_compute_loss

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add copied from
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 3d5dea9b
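After this change, passing `labels` is enough to train either model: `decoder_input_ids` are derived from the labels via `shift_tokens_right`, and the loss is computed by the encoder-decoder framework itself rather than by the decoder. A minimal, hedged usage sketch (checkpoint names and config wiring are illustrative, not taken from this commit):

import tensorflow as tf
from transformers import TFEncoderDecoderModel

# Illustrative encoder/decoder pair; any compatible checkpoints work the same way.
model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

# shift_tokens_right (added below) requires both ids on the top-level config.
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.encoder.pad_token_id

# With `labels` set, decoder_input_ids are created internally and `outputs.loss`
# is filled in by the framework (input_ids/labels are placeholders here):
# outputs = model(input_ids=input_ids, labels=labels)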
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -16,6 +16,7 @@
 import tempfile
+import warnings
 from typing import Optional
 
 import tensorflow as tf
@@ -29,7 +30,13 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFPreTrainedModel,
+    get_initializer,
+    input_processing,
+    shape_list,
+)
 from ...utils import logging
 from ..auto.configuration_auto import AutoConfig
 from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "EncoderDecoderConfig"
 
+DEPRECATION_WARNING = (
+    "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
+    "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
+    "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
+    "need to pass them yourself anymore."
+)
+
 ENCODER_DECODER_START_DOCSTRING = r"""
     This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
     encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
@@ -145,8 +159,36 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 """
 
 
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+
+    start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
+    )
+
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
 @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
-class TFEncoderDecoderModel(TFPreTrainedModel):
+class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     r"""
     [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
     of the base model classes of the library as encoder and another one as decoder when created with the
@@ -566,6 +608,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
         ):
             encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
 
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
         decoder_processing_inputs = {
             "func": self.decoder.call,
             "config": self.decoder.config,
@@ -574,7 +621,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
             "encoder_hidden_states": encoder_hidden_states,
             "encoder_attention_mask": attention_mask,
             "inputs_embeds": decoder_inputs_embeds,
-            "labels": labels,
             "output_attentions": output_attentions,
             "output_hidden_states": output_hidden_states,
             "use_cache": use_cache,
@@ -592,12 +638,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
         decoder_inputs = input_processing(**decoder_processing_inputs)
         decoder_outputs = self.decoder(**decoder_inputs)
 
-        loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
-        logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
-        past_key_values = None
+        logits = decoder_outputs[0]
+
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            warnings.warn(DEPRECATION_WARNING, FutureWarning)
+            loss = self.hf_compute_loss(labels, logits)
+
+        past_key_values = None
         if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
+            past_key_values = decoder_outputs[1]
         # The starting index of the remaining elements in `decoder_outputs`
         start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
@@ -611,7 +662,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
             return output
 
         return TFSeq2SeqLMOutput(
-            loss=decoder_outputs.loss,
+            loss=loss,
             logits=decoder_outputs.logits,
             past_key_values=past,
             decoder_hidden_states=decoder_outputs.hidden_states,
@@ -693,6 +744,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     def resize_token_embeddings(self, *args, **kwargs):
         raise NotImplementedError(
             "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported."
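The `shift_tokens_right` helper added above is self-contained; below is a standalone sketch of its behavior, using `tf.shape` in place of the library's `shape_list` and with made-up token ids:

import tensorflow as tf

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    pad = tf.cast(pad_token_id, input_ids.dtype)
    start = tf.cast(decoder_start_token_id, input_ids.dtype)
    # Prepend the decoder start token and drop the last position.
    start_tokens = tf.fill((tf.shape(input_ids)[0], 1), start)
    shifted = tf.concat([start_tokens, input_ids[:, :-1]], -1)
    # Replace the -100 ignore-index used in labels with the pad token.
    return tf.where(shifted == -100, tf.fill(tf.shape(shifted), pad), shifted)

labels = tf.constant([[5, 6, 7, -100, -100]], dtype=tf.int32)
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101).numpy())
# [[101   5   6   7   0]]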
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py

@@ -16,6 +16,7 @@
 import tempfile
+import warnings
 from typing import Optional
 
 import tensorflow as tf
@@ -29,7 +30,13 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
+from ...modeling_tf_utils import (
+    TFCausalLanguageModelingLoss,
+    TFPreTrainedModel,
+    get_initializer,
+    input_processing,
+    shape_list,
+)
 from ...utils import logging
 from ..auto.configuration_auto import AutoConfig
 from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
@@ -40,6 +47,13 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
 
+DEPRECATION_WARNING = (
+    "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the "
+    "encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if fine-tuning "
+    "a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the labels, no "
+    "need to pass them yourself anymore."
+)
+
 VISION_ENCODER_DECODER_START_DOCSTRING = r"""
     This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
     as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
@@ -134,8 +148,37 @@ VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
 """
 
 
+# Copied from transformers.models.encoder_decoder.modeling_tf_encoder_decoder.shift_tokens_right
+def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
+    if pad_token_id is None:
+        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
+    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
+
+    if decoder_start_token_id is None:
+        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
+    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
+
+    start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
+    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = tf.where(
+        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
+    )
+
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
+
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)
+
+    return shifted_input_ids
+
+
 @add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
-class TFVisionEncoderDecoderModel(TFPreTrainedModel):
+class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     r"""
     [`TFVisionEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
     with one of the base vision model classes of the library as encoder and another one of the base model classes as
@@ -594,6 +637,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
         ):
             encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
 
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
         batch_size, sequence_length = shape_list(encoder_hidden_states)[:2]
         encoder_attention_mask = tf.ones(shape=(batch_size, sequence_length), dtype=tf.int32)
@@ -605,7 +653,6 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
             "encoder_hidden_states": encoder_hidden_states,
             "encoder_attention_mask": encoder_attention_mask,
             "inputs_embeds": decoder_inputs_embeds,
-            "labels": labels,
             "output_attentions": output_attentions,
             "output_hidden_states": output_hidden_states,
             "use_cache": use_cache,
@@ -622,12 +669,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
         decoder_inputs = input_processing(**decoder_processing_inputs)
         decoder_outputs = self.decoder(**decoder_inputs)
 
-        loss = None if decoder_inputs["labels"] is None else decoder_outputs[0]
-        logits = decoder_outputs[0] if decoder_inputs["labels"] is None else decoder_outputs[1]
-        past_key_values = None
+        logits = decoder_outputs[0]
+
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            warnings.warn(DEPRECATION_WARNING, FutureWarning)
+            loss = self.hf_compute_loss(labels, logits)
+
+        past_key_values = None
         if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1] if decoder_inputs["labels"] is None else decoder_outputs[2]
+            past_key_values = decoder_outputs[1]
         # The starting index of the remaining elements in `decoder_outputs`
         start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
@@ -641,7 +693,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
             return output
 
         return TFSeq2SeqLMOutput(
-            loss=decoder_outputs.loss,
+            loss=loss,
            logits=decoder_outputs.logits,
             past_key_values=past,
             decoder_hidden_states=decoder_outputs.hidden_states,
@@ -715,6 +767,9 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     def resize_token_embeddings(self, *args, **kwargs):
         raise NotImplementedError(
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
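Both models now route the loss through `hf_compute_loss` from the `TFCausalLanguageModelingLoss` mixin. Below is a hedged approximation of its semantics, not the library code: sparse categorical cross-entropy over the logits, with -100 positions masked out and the result returned per token rather than reduced to a scalar, which is why the tests below take `tf.math.reduce_mean` before comparing with PyTorch:

import tensorflow as tf

def masked_lm_loss(labels, logits):
    # Per-token sparse cross-entropy; no reduction, so the masking stays explicit.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    flat_labels = tf.reshape(labels, (-1,))
    flat_logits = tf.reshape(logits, (-1, logits.shape[-1]))
    # Keep only positions with real labels (drop the -100 ignore-index).
    active = tf.not_equal(flat_labels, -100)
    return loss_fn(tf.boolean_mask(flat_labels, active), tf.boolean_mask(flat_logits, active))

labels = tf.constant([[2, 1, -100]])
logits = tf.random.normal((1, 3, 4))  # (batch, seq_len, vocab_size)
print(masked_lm_loss(labels, logits).shape)  # (2,): one loss value per unmasked token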
tests/test_modeling_tf_encoder_decoder.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 
+import copy
 import os
 import tempfile
 import unittest
@@ -237,7 +238,7 @@ class TFEncoderDecoderMixin:
         )
 
         # Make sure `loss` exist
-        assert "loss" in outputs_encoder_decoder
+        self.assertIn("loss", outputs_encoder_decoder)
 
         batch_size, seq_len = decoder_input_ids.shape
         expected_shape = (batch_size, seq_len, decoder_config.vocab_size)
@@ -319,12 +320,18 @@ class TFEncoderDecoderMixin:
         # prepare inputs
         tf_inputs = inputs_dict
         pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
+        if "labels" in pt_inputs:
+            pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
 
         with torch.no_grad():
             pt_outputs = pt_model(**pt_inputs).to_tuple()
-        tf_outputs = tf_model(**inputs_dict).to_tuple()
+        tf_outputs = tf_model(**inputs_dict)
+        if "loss" in tf_outputs:
+            tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
+        tf_outputs = tf_outputs.to_tuple()
         self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
         for tf_output, pt_output in zip(tf_outputs, pt_outputs):
             self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
@@ -339,8 +346,12 @@ class TFEncoderDecoderMixin:
         # This is only for copying some specific attributes of this particular model.
         tf_model_loaded.config = pt_model.config
 
-        tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
+        tf_outputs_loaded = tf_model_loaded(**inputs_dict)
+        if "loss" in tf_outputs_loaded:
+            tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
+        tf_outputs_loaded = tf_outputs_loaded.to_tuple()
         self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
         for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
             self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
@@ -435,6 +446,8 @@ class TFEncoderDecoderMixin:
     def test_pt_tf_equivalence(self):
         config_inputs_dict = self.prepare_config_and_inputs()
+        labels = config_inputs_dict.pop("decoder_token_labels")
+
         # Keep only common arguments
         arg_names = [
             "config",
@@ -454,6 +467,9 @@ class TFEncoderDecoderMixin:
         # `encoder_hidden_states` is not used in model call/forward
         del inputs_dict["encoder_hidden_states"]
 
+        inputs_dict_with_labels = copy.copy(inputs_dict)
+        inputs_dict_with_labels["labels"] = labels
+
         # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
         batch_size = inputs_dict["decoder_attention_mask"].shape[0]
         inputs_dict["decoder_attention_mask"] = tf.constant(
@@ -471,6 +487,10 @@ class TFEncoderDecoderMixin:
         self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
         self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
 
+        # check equivalence with labels
+        self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
+        self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
+
         # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
         # which randomly initialize `enc_to_dec_proj`.
         # # check `enc_to_dec_proj` work as expected
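The equivalence test now makes two adjustments before comparing outputs, sketched below with placeholder values: PyTorch's cross-entropy expects int64 (`LongTensor`) labels, and the TF loss comes back per token while the PyTorch model returns a scalar mean:

import tensorflow as tf
import torch

tf_labels = tf.constant([[5, 6, -100]], dtype=tf.int32)
# PyTorch loss functions expect LongTensor labels, hence the explicit cast.
pt_labels = torch.tensor(tf_labels.numpy()).type(torch.LongTensor)

per_token_loss = tf.constant([0.4, 0.6, 0.8])  # stand-in for a TF model's loss output
scalar_loss = tf.math.reduce_mean(per_token_loss)  # now comparable to PT's scalar mean
print(pt_labels.dtype, scalar_loss.numpy())  # torch.int64 ~0.6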
tests/test_modeling_tf_vision_encoder_decoder.py

@@ -15,6 +15,7 @@
 """ Testing suite for the TensorFlow VisionEncoderDecoder model. """
 
+import copy
 import os
 import tempfile
 import unittest
@@ -307,12 +308,18 @@ class TFVisionEncoderDecoderMixin:
         # prepare inputs
         tf_inputs = inputs_dict
         pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
+        if "labels" in pt_inputs:
+            pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
 
         with torch.no_grad():
             pt_outputs = pt_model(**pt_inputs).to_tuple()
-        tf_outputs = tf_model(**inputs_dict).to_tuple()
+        tf_outputs = tf_model(**inputs_dict)
+        if "loss" in tf_outputs:
+            tf_outputs.loss = tf.math.reduce_mean(tf_outputs.loss)
+        tf_outputs = tf_outputs.to_tuple()
         self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
         for tf_output, pt_output in zip(tf_outputs, pt_outputs):
             self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
@@ -327,8 +334,12 @@ class TFVisionEncoderDecoderMixin:
         # This is only for copying some specific attributes of this particular model.
         tf_model_loaded.config = pt_model.config
 
-        tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
+        tf_outputs_loaded = tf_model_loaded(**inputs_dict)
+        if "loss" in tf_outputs_loaded:
+            tf_outputs_loaded.loss = tf.math.reduce_mean(tf_outputs_loaded.loss)
+        tf_outputs_loaded = tf_outputs_loaded.to_tuple()
         self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
         for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
             self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
@@ -423,6 +434,8 @@ class TFVisionEncoderDecoderMixin:
     def test_pt_tf_equivalence(self):
         config_inputs_dict = self.prepare_config_and_inputs()
+        labels = config_inputs_dict.pop("decoder_token_labels")
+
         # Keep only common arguments
         arg_names = [
             "config",
@@ -441,6 +454,9 @@ class TFVisionEncoderDecoderMixin:
         # `encoder_hidden_states` is not used in model call/forward
         del inputs_dict["encoder_hidden_states"]
 
+        inputs_dict_with_labels = copy.copy(inputs_dict)
+        inputs_dict_with_labels["labels"] = labels
+
         # Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
         batch_size = inputs_dict["decoder_attention_mask"].shape[0]
         inputs_dict["decoder_attention_mask"] = tf.constant(
@@ -458,6 +474,10 @@ class TFVisionEncoderDecoderMixin:
         self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
         self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
 
+        # check equivalence with labels
+        self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict_with_labels)
+        self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict_with_labels)
+
         # This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
         # which randomly initialize `enc_to_dec_proj`.
         # # check `enc_to_dec_proj` work as expected
@@ -543,6 +563,7 @@ class TFViT2GPT2EncoderDecoderModelTest(TFVisionEncoderDecoderMixin, unittest.TestCase):
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
+            "decoder_token_labels": decoder_token_labels,
             "encoder_hidden_states": encoder_hidden_states,  # This is not used in the tests.
             "labels": decoder_token_labels,
         }
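The fixtures now return the labels twice: once under `labels` for the direct model call and once under `decoder_token_labels`, so that `test_pt_tf_equivalence` can pop them off and run the check both with and without the loss path. A minimal sketch of that flow (values are placeholders):

import copy

config_inputs_dict = {
    "decoder_input_ids": "...",
    "decoder_attention_mask": "...",
    "decoder_token_labels": "...",  # same tensor the fixture also returns as `labels`
}

labels = config_inputs_dict.pop("decoder_token_labels")
inputs_dict = config_inputs_dict                  # equivalence check without loss
inputs_dict_with_labels = copy.copy(inputs_dict)  # shallow copy shares the tensors
inputs_dict_with_labels["labels"] = labels        # equivalence check with loss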