Commit 345c23a6 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Replace (TF)CommonTestCases for modeling with a mixin.

I suspect the wrapper classes were created in order to prevent the
abstract base class (TF)CommonModelTester from being included in test
discovery and running, because that would fail.

I solved this by replacing the abstract base class with a mixin.

Code changes are just de-indenting and automatic reformattings
performed by black to use the extra line space.
parent 7e98e211
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import random import random
import unittest
from transformers import TransfoXLConfig, is_tf_available from transformers import TransfoXLConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -33,7 +34,7 @@ if is_tf_available(): ...@@ -33,7 +34,7 @@ if is_tf_available():
@require_tf @require_tf
class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else () all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
test_pruning = False test_pruning = False
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_tf_available from transformers import is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -34,7 +36,7 @@ if is_tf_available(): ...@@ -34,7 +36,7 @@ if is_tf_available():
@require_tf @require_tf
class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
(TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple) (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import random import random
import unittest
from transformers import XLNetConfig, is_tf_available from transformers import XLNetConfig, is_tf_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -37,7 +38,7 @@ if is_tf_available(): ...@@ -37,7 +38,7 @@ if is_tf_available():
@require_tf @require_tf
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import random import random
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -30,7 +31,7 @@ if is_torch_available(): ...@@ -30,7 +31,7 @@ if is_torch_available():
@require_torch @require_torch
class TransfoXLModelTest(CommonTestCases.CommonModelTester): class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
test_pruning = False test_pruning = False
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -34,7 +36,7 @@ if is_torch_available(): ...@@ -34,7 +36,7 @@ if is_torch_available():
@require_torch @require_torch
class XLMModelTest(CommonTestCases.CommonModelTester): class XLMModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import random import random
import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import CommonTestCases, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
...@@ -38,7 +39,7 @@ if is_torch_available(): ...@@ -38,7 +39,7 @@ if is_torch_available():
@require_torch @require_torch
class XLNetModelTest(CommonTestCases.CommonModelTester): class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = ( all_model_classes = (
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment