"...linux/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "7332dfab048e65baa0bf43b6c2b467fa48be97f3"
Commit 345c23a6 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Replace (TF)CommonTestCases for modeling with a mixin.

I suspect the wrapper classes were created in order to prevent the
abstract base class (TF)CommonModelTester from being included in test
discovery and running, because that would fail.

I solved this by replacing the abstract base class with a mixin.

Code changes are just de-indenting and automatic reformattings
performed by black to use the extra line space.
parent 7e98e211
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import XxxConfig, is_tf_available from transformers import XxxConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -32,7 +34,7 @@ if is_tf_available(): ...@@ -32,7 +34,7 @@ if is_tf_available():
@require_tf @require_tf
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -34,7 +36,7 @@ if is_torch_available(): ...@@ -34,7 +36,7 @@ if is_torch_available():
@require_torch @require_torch
class XxxModelTest(CommonTestCases.CommonModelTester): class XxxModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification) (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -33,7 +35,7 @@ if is_torch_available(): ...@@ -33,7 +35,7 @@ if is_torch_available():
@require_torch @require_torch
class AlbertModelTest(CommonTestCases.CommonModelTester): class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else () all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, floats_tensor, ids_tensor from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -37,7 +39,7 @@ if is_torch_available(): ...@@ -37,7 +39,7 @@ if is_torch_available():
@require_torch @require_torch
class BertModelTest(CommonTestCases.CommonModelTester): class BertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -69,9 +69,8 @@ def _config_zero_init(config): ...@@ -69,9 +69,8 @@ def _config_zero_init(config):
return configs_no_init return configs_no_init
class CommonTestCases: @require_torch
@require_torch class ModelTesterMixin:
class CommonModelTester(unittest.TestCase):
model_tester = None model_tester = None
all_model_classes = () all_model_classes = ()
...@@ -612,7 +611,8 @@ class CommonTestCases: ...@@ -612,7 +611,8 @@ class CommonTestCases:
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
class GPTModelTester(CommonModelTester):
class GPTModelTester(ModelTesterMixin):
def __init__( def __init__(
self, self,
parent, parent,
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -25,7 +27,7 @@ if is_torch_available(): ...@@ -25,7 +27,7 @@ if is_torch_available():
@require_torch @require_torch
class CTRLModelTest(CommonTestCases.CommonModelTester): class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else () all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
test_pruning = False test_pruning = False
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, torch_device from .utils import require_torch, torch_device
...@@ -33,7 +35,7 @@ if is_torch_available(): ...@@ -33,7 +35,7 @@ if is_torch_available():
@require_torch @require_torch
class DistilBertModelTest(CommonTestCases.CommonModelTester): class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification) (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -32,7 +34,7 @@ if is_torch_available(): ...@@ -32,7 +34,7 @@ if is_torch_available():
@require_torch @require_torch
class GPT2ModelTest(CommonTestCases.CommonModelTester): class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -32,7 +34,7 @@ if is_torch_available(): ...@@ -32,7 +34,7 @@ if is_torch_available():
@require_torch @require_torch
class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else () (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -37,7 +37,7 @@ if is_torch_available(): ...@@ -37,7 +37,7 @@ if is_torch_available():
@require_torch @require_torch
class RobertaModelTest(CommonTestCases.CommonModelTester): class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else () all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow from .utils import CACHE_DIR, require_torch, slow
...@@ -27,7 +29,7 @@ if is_torch_available(): ...@@ -27,7 +29,7 @@ if is_torch_available():
@require_torch @require_torch
class T5ModelTest(CommonTestCases.CommonModelTester): class T5ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else () all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
test_pruning = False test_pruning = False
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import AlbertConfig, is_tf_available from transformers import AlbertConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -31,7 +33,7 @@ if is_tf_available(): ...@@ -31,7 +33,7 @@ if is_tf_available():
@require_tf @require_tf
class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else () (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import BertConfig, is_tf_available from transformers import BertConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -36,7 +38,7 @@ if is_tf_available(): ...@@ -36,7 +38,7 @@ if is_tf_available():
@require_tf @require_tf
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -20,7 +20,6 @@ import random ...@@ -20,7 +20,6 @@ import random
import shutil import shutil
import sys import sys
import tempfile import tempfile
import unittest
from transformers import is_tf_available, is_torch_available from transformers import is_tf_available, is_torch_available
...@@ -59,9 +58,8 @@ def _config_zero_init(config): ...@@ -59,9 +58,8 @@ def _config_zero_init(config):
return configs_no_init return configs_no_init
class TFCommonTestCases: @require_tf
@require_tf class TFModelTesterMixin:
class TFCommonModelTester(unittest.TestCase):
model_tester = None model_tester = None
all_model_classes = () all_model_classes = ()
...@@ -168,12 +166,8 @@ class TFCommonTestCases: ...@@ -168,12 +166,8 @@ class TFCommonTestCases:
if self.is_encoder_decoder: if self.is_encoder_decoder:
input_ids = { input_ids = {
"decoder_input_ids": tf.keras.Input( "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32" "encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
),
"encoder_input_ids": tf.keras.Input(
batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"
),
} }
else: else:
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32") input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
...@@ -209,9 +203,7 @@ class TFCommonTestCases: ...@@ -209,9 +203,7 @@ class TFCommonTestCases:
outputs_dict = model(inputs_dict) outputs_dict = model(inputs_dict)
inputs_keywords = copy.deepcopy(inputs_dict) inputs_keywords = copy.deepcopy(inputs_dict)
input_ids = inputs_keywords.pop( input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None)
"input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None
)
outputs_keywords = model(input_ids, **inputs_keywords) outputs_keywords = model(input_ids, **inputs_keywords)
output_dict = outputs_dict[0].numpy() output_dict = outputs_dict[0].numpy()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import CTRLConfig, is_tf_available from transformers import CTRLConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -26,7 +28,7 @@ if is_tf_available(): ...@@ -26,7 +28,7 @@ if is_tf_available():
@require_tf @require_tf
class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else () all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import DistilBertConfig, is_tf_available from transformers import DistilBertConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf from .utils import require_tf
...@@ -31,7 +33,7 @@ if is_tf_available(): ...@@ -31,7 +33,7 @@ if is_tf_available():
@require_tf @require_tf
class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import GPT2Config, is_tf_available from transformers import GPT2Config, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -32,7 +34,7 @@ if is_tf_available(): ...@@ -32,7 +34,7 @@ if is_tf_available():
@require_tf @require_tf
class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else () all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else () # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import OpenAIGPTConfig, is_tf_available from transformers import OpenAIGPTConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -32,7 +34,7 @@ if is_tf_available(): ...@@ -32,7 +34,7 @@ if is_tf_available():
@require_tf @require_tf
class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester): class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else () (TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
from transformers import RobertaConfig, is_tf_available from transformers import RobertaConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -36,7 +36,7 @@ if is_tf_available(): ...@@ -36,7 +36,7 @@ if is_tf_available():
@require_tf @require_tf
class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester): class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else () (TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else ()
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import T5Config, is_tf_available from transformers import T5Config, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -26,7 +28,7 @@ if is_tf_available(): ...@@ -26,7 +28,7 @@ if is_tf_available():
@require_tf @require_tf
class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester): class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = True is_encoder_decoder = True
all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else () all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment