# utils.py: helpers shared by the test suite.
import os
import unittest
from distutils.util import strtobool

from transformers.file_utils import _tf_available, _torch_available


# Identifier of a small dummy model, for use in tests.
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"

# Used to test Auto{Config, Model, Tokenizer} model_type detection.
# NOTE: "UNKWOWN" is a misspelling kept for backward compatibility;
# DUMMY_UNKNOWN_IDENTIFIER is a correctly spelled alias for the same value.
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_UNKNOWN_IDENTIFIER = DUMMY_UNKWOWN_IDENTIFIER


# Strings accepted by ``distutils.util.strtobool`` (distutils was removed in
# Python 3.12), matched case-insensitively and mapped to 1 / 0 exactly as
# strtobool did, so existing RUN_* settings keep working.
_TRUE_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSE_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})


def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable *key*.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when *key* is not set.

    Returns:
        ``default`` if *key* is unset; otherwise ``1`` when the value is one of
        ``y, yes, t, true, on, 1`` or ``0`` when it is one of
        ``n, no, f, false, off, 0`` (case-insensitive).

    Raises:
        ValueError: If *key* is set to any other string.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to 1 or 0 (same semantics as strtobool).
    normalized = value.lower()
    if normalized in _TRUE_STRINGS:
        return 1
    if normalized in _FALSE_STRINGS:
        return 0
    # More values are supported, but let's keep the message simple.
    raise ValueError("If set, {} must be yes or no.".format(key))

def parse_int_from_env(key, default=None):
    """Read an integer from the environment variable *key*.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when *key* is not set.

    Returns:
        ``default`` if *key* is unset, otherwise ``int(os.environ[key])``.

    Raises:
        ValueError: If *key* is set to a string ``int()`` cannot parse. The
            original conversion error is chained as ``__cause__`` so the
            offending literal stays visible in the traceback.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError("If set, {} must be an int.".format(key)) from err


# Opt-in switches read once, at import time, from the environment.
# RUN_SLOW enables tests decorated with @slow; RUN_CUSTOM_TOKENIZERS enables
# tests decorated with @custom_tokenizers. Both are off unless set.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)

# Optional integer from TF_GPU_MEMORY_LIMIT; None when the variable is unset.
# NOTE(review): units and consumer are not visible here - confirm at call site.
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)


def slow(test_case):
    """
    Mark *test_case* as a slow test.

    The test runs only when the RUN_SLOW environment variable held a truthy
    value at import time; otherwise it is wrapped with ``unittest.skip`` and
    reported as skipped with the reason "test is slow".
    """
    if _run_slow_tests:
        return test_case
    return unittest.skip("test is slow")(test_case)


def custom_tokenizers(test_case):
    """
    Decorator marking a test for a custom tokenizer.

    Custom tokenizers require additional dependencies, and are skipped
    by default. Set the RUN_CUSTOM_TOKENIZERS environment variable
    to a truthy value to run them.

    Args:
        test_case: The test to decorate.

    Returns:
        ``test_case`` unchanged when custom-tokenizer tests are enabled,
        otherwise ``test_case`` wrapped with
        ``unittest.skip("test of custom tokenizers")``.
    """
    if not _run_custom_tokenizers:
        test_case = unittest.skip("test of custom tokenizers")(test_case)
    return test_case


def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed, as reported by
    ``_torch_available`` (imported from ``transformers.file_utils``).

    Args:
        test_case: The test to decorate.

    Returns:
        ``test_case`` unchanged when PyTorch is available, otherwise
        ``test_case`` wrapped with ``unittest.skip("test requires PyTorch")``.
    """
    if not _torch_available:
        test_case = unittest.skip("test requires PyTorch")(test_case)
    return test_case


def require_tf(test_case):
    """
    Skip *test_case* unless TensorFlow is installed.

    Availability comes from ``_tf_available``; when it is false the test is
    wrapped with ``unittest.skip`` and reported with the reason
    "test requires TensorFlow".
    """
    if _tf_available:
        return test_case
    return unittest.skip("test requires TensorFlow")(test_case)


def require_multigpu(test_case):
    """
    Decorator marking a test that requires a multi-GPU setup (in PyTorch).

    These tests are skipped on a machine without multiple GPUs, and also when
    PyTorch itself isn't installed.

    To run *only* the multigpu tests, assuming all test names contain multigpu:
    $ pytest -sv ./tests -k "multigpu"

    Args:
        test_case: The test to decorate.

    Returns:
        ``test_case`` unchanged when ``torch.cuda.device_count() >= 2``,
        otherwise ``test_case`` wrapped with ``unittest.skip``.
    """
    if not _torch_available:
        return unittest.skip("test requires PyTorch")(test_case)

    # Imported lazily: torch is only importable once availability is confirmed.
    import torch

    if torch.cuda.device_count() < 2:
        return unittest.skip("test requires multiple GPUs")(test_case)
    return test_case


# Device string exposed to tests: "cuda" when USE_CUDA is truthy, "cpu"
# otherwise, and None when PyTorch is not installed.
if _torch_available:
    # Set the USE_CUDA environment variable to select a GPU.
    torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
else:
    torch_device = None