Unverified commit edb704b2 authored by Matt, committed by GitHub

Fix inverted conditional in TF common test! (#22540)

* Fix inverted conditional in TF common test!

* Make the same change in the PT tests file

* Make sure hidden states for GPT2 have the same output shape in PT/TF

* Minor fix to PT implementation of token classification loss

* Skip loss equivalence test for TFHubert because it keeps overflowing to inf

* Compute LM loss for TF the (weird) way it's computed in PT

* Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert

* Fix - don't try to access the hidden states property when output is a tuple
parent 48fbd8fa
@@ -1228,16 +1228,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[2:]
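Note on the change above: the removed branch masked out padded positions by hand using the attention mask. The simplified version relies on CrossEntropyLoss skipping positions whose label is -100 (its default ignore_index), the usual convention for padded tokens. A minimal, self-contained sketch of that behaviour; the tensors here are made up for illustration:

import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
logits = torch.randn(2, 4, num_labels)      # (batch, seq_len, num_labels)
labels = torch.tensor([[1, 2, -100, -100],  # -100 marks padded / ignored positions
                       [0, 3, 4, -100]])

loss_fct = CrossEntropyLoss()               # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
print(loss.item())                          # padded positions do not contribute to the loss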
@@ -1051,6 +1051,12 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+        if return_dict and output_hidden_states:
+            # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+            # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
+            all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+        else:
+            all_hidden_states = None
         lm_logits = self.transformer.wte(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
@@ -1062,7 +1068,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             logits=lm_logits,
             mc_logits=mc_logits,
             past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
+            hidden_states=all_hidden_states,
             attentions=transformer_outputs.attentions,
         )
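Note on the change above: for the double-heads model the input is rank 3 (batch, num_choices, seq_len), the transformer runs on a flattened view, and only the final hidden state is reshaped back to rank 4; per the comment in the diff, the PT implementation leaves the earlier hidden states at rank 3, which the new branch reproduces. A rough sketch of the shape bookkeeping with stand-in tensors (the shapes and names here are illustrative, not the model's internals):

import tensorflow as tf

batch, num_choices, seq_len, hidden = 2, 3, 5, 8
input_shapes = [batch, num_choices, seq_len]  # shape of the rank-3 input_ids

# The transformer internally works on a flattened (batch * num_choices, seq_len, hidden) tensor
hidden_states_tuple = tuple(
    tf.random.normal((batch * num_choices, seq_len, hidden)) for _ in range(2)
)

# Only the final hidden state is reshaped back to rank 4
final = tf.reshape(hidden_states_tuple[-1], input_shapes + [hidden])

# The returned tuple keeps the earlier rank-3 entries and swaps in the rank-4 final state
all_hidden_states = hidden_states_tuple[:-1] + (final,)
print([tuple(h.shape) for h in all_hidden_states])  # [(6, 5, 8), (2, 3, 5, 8)]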
@@ -953,9 +953,11 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
         loss = None
         if labels is not None:
             # shift labels to the left and cut last logit token
-            shifted_logits = lm_logits[:, :-1]
-            labels = labels[:, 1:]
-            loss = self.hf_compute_loss(labels, shifted_logits)
+            labels = tf.concat(
+                [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
+                axis=-1,
+            )
+            loss = self.hf_compute_loss(labels, lm_logits)
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
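Note on the change above: the TF loss now mirrors the PT XGLM implementation. Instead of dropping the last logit, the labels are shifted one step to the left and the final position is filled with pad_token_id, so labels and logits keep the same sequence length. A small standalone sketch of just the label shift; the values are made up:

import tensorflow as tf

pad_token_id = 1
labels = tf.constant([[5, 6, 7, 8],
                      [9, 10, 11, 12]])

# Shift labels left by one and pad the final position so that labels[:, t] is the
# target for logits[:, t] and both tensors keep the same length
shifted = tf.concat(
    [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(pad_token_id, labels.dtype))],
    axis=-1,
)
print(shifted.numpy())  # [[ 6  7  8  1]
                        #  [10 11 12  1]]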
@@ -17,13 +17,15 @@
 import copy
 import inspect
 import math
+import os
+import tempfile
 import unittest
 
 import numpy as np
 import pytest
 
 from transformers import is_tf_available
-from transformers.testing_utils import require_soundfile, require_tf, slow
+from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -333,6 +335,62 @@ class TFHubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Hubert models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                # Original test: check without `labels`
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
 
 @require_tf
 class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -458,6 +516,62 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Hubert models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                # Original test: check without `labels`
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
 
 @require_tf
 class TFHubertUtilsTest(unittest.TestCase):
@@ -19,6 +19,8 @@ import glob
 import inspect
 import math
 import multiprocessing
+import os
+import tempfile
 import traceback
 import unittest
 
@@ -31,6 +33,7 @@ from transformers import Wav2Vec2Config, is_tf_available
 from transformers.testing_utils import (
     CaptureLogger,
     is_flaky,
+    is_pt_tf_cross_test,
     require_librosa,
     require_pyctcdecode,
     require_tf,
@@ -397,6 +400,62 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                # Original test: check without `labels`
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
 
 @require_tf
 class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -524,6 +583,62 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass
 
+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                # Original test: check without `labels`
+                self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
 
 @require_tf
 class TFWav2Vec2UtilsTest(unittest.TestCase):
@@ -2030,7 +2030,7 @@ class ModelTesterMixin:
             # For some models (e.g. base models), there is no label returned.
             # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
-            if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
+            if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
                 pt_inputs_dict_with_labels = None
 
             # Check we can load pt model in tf and vice-versa with model => model functions
@@ -699,7 +699,7 @@ class TFModelTesterMixin:
             # For some models (e.g. base models), there is no label returned.
             # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
-            if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
+            if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
                 tf_inputs_dict_with_labels = None
 
             # Check we can load pt model in tf and vice-versa with model => model functions
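Note on the two changes above (PT and TF common test mixins): this is the inverted conditional the commit title refers to. The symmetric difference of the two key sets is empty exactly when preparing the inputs added no labels, and only in that case should the with-labels pass be skipped; the previous code had it backwards and skipped the labelled pass whenever labels were present. A quick illustration with hypothetical input dicts:

inputs = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}
inputs_with_labels = dict(inputs, labels=[0, 1, 2])

# Labels were added, so the symmetric difference is non-empty and the corrected
# check keeps the with-labels dict for a second equivalence pass
print(set(inputs_with_labels).symmetric_difference(inputs))  # {'labels'}

# For a base model no labels are added; the key sets match, the difference is empty,
# and the with-labels dict is dropped to avoid checking the same inputs twice
base_inputs_with_labels = dict(inputs)
if not set(base_inputs_with_labels).symmetric_difference(inputs):
    base_inputs_with_labels = None
print(base_inputs_with_labels)  # None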