Unverified Commit edb704b2 authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix inverted conditional in TF common test! (#22540)

* Fix inverted conditional in TF common test!

* Make the same change in the PT tests file

* Make sure hidden states for GPT2 have the same output shape in PT/TF

* Minor fix to PT implementation of token classification loss

* Skip loss equivalence test for TFHubert because it keeps overflowing to inf

* Compute LM loss for TF the (weird) way it's computed in PT

* Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert

* Fix - don't try to access the hidden states property when output is a tuple
parent 48fbd8fa
......@@ -1228,16 +1228,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
......
......@@ -1051,6 +1051,12 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
)
hidden_states = transformer_outputs[0]
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
if return_dict and output_hidden_states:
# We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
# input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
else:
all_hidden_states = None
lm_logits = self.transformer.wte(hidden_states, mode="linear")
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
mc_logits = tf.squeeze(mc_logits, axis=-1)
......@@ -1062,7 +1068,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
hidden_states=all_hidden_states,
attentions=transformer_outputs.attentions,
)
......
......@@ -953,9 +953,11 @@ class TFXGLMForCausalLM(TFXGLMPreTrainedModel, TFCausalLanguageModelingLoss):
loss = None
if labels is not None:
# shift labels to the left and cut last logit token
shifted_logits = lm_logits[:, :-1]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels, shifted_logits)
labels = tf.concat(
[labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
axis=-1,
)
loss = self.hf_compute_loss(labels, lm_logits)
if not return_dict:
output = (lm_logits,) + outputs[1:]
......
......@@ -17,13 +17,15 @@
import copy
import inspect
import math
import os
import tempfile
import unittest
import numpy as np
import pytest
from transformers import is_tf_available
from transformers.testing_utils import require_soundfile, require_tf, slow
from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
......@@ -333,6 +335,62 @@ class TFHubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
import torch
import transformers
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
self._make_attention_mask_non_null(inputs_dict)
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@require_tf
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
......@@ -458,6 +516,62 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
import torch
import transformers
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
self._make_attention_mask_non_null(inputs_dict)
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@require_tf
class TFHubertUtilsTest(unittest.TestCase):
......
......@@ -19,6 +19,8 @@ import glob
import inspect
import math
import multiprocessing
import os
import tempfile
import traceback
import unittest
......@@ -31,6 +33,7 @@ from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
is_pt_tf_cross_test,
require_librosa,
require_pyctcdecode,
require_tf,
......@@ -397,6 +400,62 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
import torch
import transformers
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
self._make_attention_mask_non_null(inputs_dict)
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@require_tf
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
......@@ -524,6 +583,62 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
pass
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
import torch
import transformers
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
self._make_attention_mask_non_null(inputs_dict)
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
)
pt_model = transformers.load_tf2_model_in_pytorch_model(
pt_model, tf_model, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
)
# Original test: check without `labels`
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@require_tf
class TFWav2Vec2UtilsTest(unittest.TestCase):
......
......@@ -2030,7 +2030,7 @@ class ModelTesterMixin:
# For some models (e.g. base models), there is no label returned.
# Set the input dict to `None` to avoid check outputs twice for the same input dicts.
if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
pt_inputs_dict_with_labels = None
# Check we can load pt model in tf and vice-versa with model => model functions
......
......@@ -699,7 +699,7 @@ class TFModelTesterMixin:
# For some models (e.g. base models), there is no label returned.
# Set the input dict to `None` to avoid check outputs twice for the same input dicts.
if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
tf_inputs_dict_with_labels = None
# Check we can load pt model in tf and vice-versa with model => model functions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment