Unverified commit 13deb95a authored by Sam Shleifer, committed by GitHub
Browse files

Move tests/utils.py -> transformers/testing_utils.py (#5350)

parent 9c219305
......@@ -12,7 +12,7 @@ export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py and utils.py
# Add parent directory to python path to access lightning_base.py and testing_utils.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=cnn_tiny/ \
......
......@@ -12,6 +12,7 @@ import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.testing_utils import require_multigpu
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
......@@ -107,7 +108,7 @@ class TestSummarizationDistiller(unittest.TestCase):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
@require_multigpu
def test_multigpu(self):
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
self._test_distiller_cli(updates)
......
import unittest
from transformers import is_torch_available
from .utils import require_torch
from transformers.testing_utils import require_torch
if is_torch_available():
......
......@@ -4,8 +4,7 @@ import unittest
from pathlib import Path
from transformers import AutoConfig, is_torch_available
from .utils import require_torch, torch_device
from transformers.testing_utils import require_torch, torch_device
if is_torch_available():
......
......@@ -4,8 +4,7 @@ import unittest
from pathlib import Path
from transformers import AutoConfig, is_tf_available
from .utils import require_tf
from transformers.testing_utils import require_tf
if is_tf_available():
......
......@@ -19,8 +19,7 @@ import unittest
from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
from .utils import DUMMY_UNKWOWN_IDENTIFIER
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
......
......@@ -21,8 +21,7 @@ from pathlib import Path
from typing import List, Union
import transformers
from .utils import require_tf, require_torch, slow
from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -17,8 +17,7 @@
import unittest
from transformers import is_torch_available
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
if is_torch_available():
......
......@@ -20,10 +20,10 @@ import timeout_decorator # noqa
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -16,8 +16,7 @@
import unittest
from transformers import is_torch_available
from .utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -21,8 +21,7 @@ import unittest
from typing import List
from transformers import is_torch_available
from .utils import require_multigpu, require_torch, slow, torch_device
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():
......
......@@ -16,10 +16,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, torch_device
if is_torch_available():
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -18,12 +18,12 @@ import tempfile
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
......@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment