Unverified Commit 13deb95a authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Move tests/utils.py -> transformers/testing_utils.py (#5350)

parent 9c219305
...@@ -12,7 +12,7 @@ export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} ...@@ -12,7 +12,7 @@ export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist # Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py and utils.py # Add parent directory to python path to access lightning_base.py and testing_utils.py
export PYTHONPATH="../":"${PYTHONPATH}" export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \ python finetune.py \
--data_dir=cnn_tiny/ \ --data_dir=cnn_tiny/ \
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint from .distillation import distill_main, evaluate_checkpoint
from .finetune import main from .finetune import main
...@@ -107,7 +108,7 @@ class TestSummarizationDistiller(unittest.TestCase): ...@@ -107,7 +108,7 @@ class TestSummarizationDistiller(unittest.TestCase):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test") @require_multigpu
def test_multigpu(self): def test_multigpu(self):
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,) updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
self._test_distiller_cli(updates) self._test_distiller_cli(updates)
......
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch
from .utils import require_torch
if is_torch_available(): if is_torch_available():
......
...@@ -4,8 +4,7 @@ import unittest ...@@ -4,8 +4,7 @@ import unittest
from pathlib import Path from pathlib import Path
from transformers import AutoConfig, is_torch_available from transformers import AutoConfig, is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .utils import require_torch, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -4,8 +4,7 @@ import unittest ...@@ -4,8 +4,7 @@ import unittest
from pathlib import Path from pathlib import Path
from transformers import AutoConfig, is_tf_available from transformers import AutoConfig, is_tf_available
from transformers.testing_utils import require_tf
from .utils import require_tf
if is_tf_available(): if is_tf_available():
......
...@@ -19,8 +19,7 @@ import unittest ...@@ -19,8 +19,7 @@ import unittest
from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.configuration_bert import BertConfig from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig from transformers.configuration_roberta import RobertaConfig
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
from .utils import DUMMY_UNKWOWN_IDENTIFIER
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
......
...@@ -21,8 +21,7 @@ from pathlib import Path ...@@ -21,8 +21,7 @@ from pathlib import Path
from typing import List, Union from typing import List, Union
import transformers import transformers
from transformers.testing_utils import require_tf, require_torch, slow
from .utils import require_tf, require_torch, slow
logger = logging.getLogger() logger = logging.getLogger()
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
if is_torch_available(): if is_torch_available():
......
...@@ -20,10 +20,10 @@ import timeout_decorator # noqa ...@@ -20,10 +20,10 @@ import timeout_decorator # noqa
from transformers import is_torch_available from transformers import is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -21,8 +21,7 @@ import unittest ...@@ -21,8 +21,7 @@ import unittest
from typing import List from typing import List
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -18,12 +18,12 @@ import tempfile ...@@ -18,12 +18,12 @@ import tempfile
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented # TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTester from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor from .test_modeling_common import ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment