common_utils.py 7.32 KB
Newer Older
1
2
import contextlib
import functools
3
import os
4
import random
5
6
import shutil
import tempfile
7
8

import numpy as np
eellison's avatar
eellison committed
9
import torch
10
from PIL import Image
11
from torchvision import io
12

13
import __main__  # noqa: 401
14

15

Philip Meier's avatar
Philip Meier committed
16
17
18
# Flags describing the environment the test suite runs in, derived from environment variables.
IN_CIRCLE_CI = os.getenv("CIRCLECI") == "true"
IN_RE_WORKER = "INSIDE_RE_WORKER" in os.environ
IN_FBCODE = os.getenv("IN_FBCODE_TORCHVISION") == "1"

# Reusable skip / diagnostic messages.
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."
21

22
23
24
25
26
27
28
29
30
31
32

@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    """Yield the path of a fresh temporary directory and delete it when the ``with`` block exits.

    Args:
        src: Optional directory whose contents are copied into the temporary directory before it is yielded.
        **kwargs: Forwarded to :func:`tempfile.mkdtemp` (e.g. ``prefix``, ``suffix``, ``dir``).
    """
    tmp_dir = tempfile.mkdtemp(**kwargs)
    try:
        if src is not None:
            # copytree requires a destination that does not exist yet, so drop the empty directory first.
            os.rmdir(tmp_dir)
            # Inside the try: a copy that fails halfway must not leak the partially copied tree.
            shutil.copytree(src, tmp_dir)
        yield tmp_dir
    finally:
        # The copy may have failed before (re)creating the directory; only remove what actually exists.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
eellison's avatar
eellison committed
33
34


35
36
37
38
39
def set_rng_seed(seed):
    """Seed Python's ``random`` module and torch's global RNG with the same ``seed``."""
    random.seed(seed)
    torch.manual_seed(seed)


40
class MapNestedTensorObjectImpl:
    """Callable that applies ``tensor_map_fn`` to every tensor inside an arbitrarily nested container.

    Dicts (both keys and values), lists and tuples are traversed recursively. Dicts come back as plain
    ``dict``, lists as plain ``list`` and tuples as plain ``tuple``. Any other non-tensor object is
    returned unchanged.
    """

    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        if isinstance(object, dict):
            result = {}
            for key, value in object.items():
                # Map the value before the key (the evaluation order of ``d[f(k)] = f(v)``).
                mapped_value = self(value)
                result[self(key)] = mapped_value
            return result

        if isinstance(object, (list, tuple)):
            items = [self(item) for item in object]
            return tuple(items) if isinstance(object, tuple) else items

        return object


def map_nested_tensor_object(object, tensor_map_fn):
    """Apply ``tensor_map_fn`` to every tensor nested inside ``object`` and return the rebuilt structure."""
    return MapNestedTensorObjectImpl(tensor_map_fn)(object)


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def is_iterable(obj):
    """Return ``True`` if ``iter(obj)`` succeeds and ``False`` if it raises ``TypeError``."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True


@contextlib.contextmanager
def freeze_rng_state():
    """Snapshot torch's CPU RNG state (and the CUDA one, if CUDA is available) and restore it on exit.

    The state is restored in a ``finally`` clause, so it is reset even if the ``with`` body raises.
    A failing test therefore cannot leak a perturbed RNG state into later tests.
    """
    cuda_available = torch.cuda.is_available()
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if cuda_available else None
    try:
        yield
    finally:
        if cuda_available:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
86
87


88
def cycle_over(objs):
    """Yield every ordered pair ``(a, b)`` of elements of ``objs`` taken from different positions.

    ``objs`` must support slicing and ``+`` (e.g. a list or tuple). Pairs are ordered by the position of
    the first element, then by the position of the second.
    """
    for position, first in enumerate(objs):
        remaining = objs[:position] + objs[position + 1 :]
        for second in remaining:
            yield first, second
92
93
94


def int_dtypes():
    """Return the tuple of torch integer dtypes exercised by the tests."""
    dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    return dtypes
96
97
98


def float_dtypes():
    """Return the tuple of torch floating-point dtypes exercised by the tests."""
    dtypes = (torch.float32, torch.float64)
    return dtypes
100
101
102
103
104
105
106
107


@contextlib.contextmanager
def disable_console_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of the ``with`` block.

    Both streams are redirected to ``os.devnull``. The devnull handle is opened first and therefore closed
    last, so the original streams are restored before the file they were redirected to is closed.
    """
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
        yield
108
109


110
111
def cpu_and_gpu():
    """Return device parameters for ``pytest.mark.parametrize``: ``"cpu"`` and a ``needs_cuda``-marked ``"cuda"``."""
    import pytest  # noqa

    cuda_param = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return ("cpu", cuda_param)
114
115
116
117


def needs_cuda(test_func):
    """Apply the custom ``needs_cuda`` pytest marker to ``test_func`` and return the result."""
    import pytest  # noqa

    marker = pytest.mark.needs_cuda
    return marker(test_func)
Nicolas Hug's avatar
Nicolas Hug committed
120
121
122
123
124


def _create_data(height=3, width=3, channels=3, device="cpu"):
    """Create a random uint8 image both as a tensor and as a PIL image holding the same pixels.

    Returns:
        ``(tensor, pil_img)``. ``tensor`` has shape ``(channels, height, width)`` with values in ``[0, 255]``
        on ``device``. ``pil_img`` is built from the same values in height-width-channel order, with mode
        ``"L"`` when ``channels == 1`` and ``"RGB"`` otherwise.
    """
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    hwc = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    if channels == 1:
        pil_img = Image.fromarray(hwc[..., 0], mode="L")
    else:
        pil_img = Image.fromarray(hwc, mode="RGB")
    return tensor, pil_img


def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
136
    batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device)
Nicolas Hug's avatar
Nicolas Hug committed
137
138
139
    return batch_tensor


140
# Exact-equality assertion: torch.testing.assert_close with both tolerances forced to zero.
assert_equal = functools.partial(torch.testing.assert_close, atol=0, rtol=0)
141
142


143
144
145
146
147
148
149
150
151
152
153
154
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    """Write ``num_videos`` random videos into ``tmpdir`` and return their file paths.

    Video ``i`` is saved as ``{tmpdir}/{i}.mp4``, built from a random uint8 tensor of shape
    ``(size, 300, 400, 3)``.

    Args:
        tmpdir: Directory the ``.mp4`` files are written to.
        num_videos: Number of videos to create.
        sizes: Optional per-video length of the first tensor dimension; defaults to ``5 * (i + 1)``.
        fps: Optional per-video frame rate passed to ``write_video``; defaults to 5 for every video.
    """
    paths = []
    for idx in range(num_videos):
        num_frames = 5 * (idx + 1) if sizes is None else sizes[idx]
        frame_rate = 5 if fps is None else fps[idx]
        frames = torch.randint(0, 256, (num_frames, 300, 400, 3), dtype=torch.uint8)
        path = os.path.join(tmpdir, f"{idx}.mp4")
        paths.append(path)
        io.write_video(path, frames, fps=frame_rate)
    return paths


Nicolas Hug's avatar
Nicolas Hug committed
162
163
164
165
166
167
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    """Assert that ``tensor`` exactly equals the pixels of ``pil_image`` laid out channels-first.

    A 2-D pixel array (single channel) gets a trailing channel axis before being transposed to
    ``(C, H, W)``. ``tensor`` is moved to the CPU for the comparison. If ``msg`` is omitted, a message
    showing both tensors is used.
    """
    pixels = np.array(pil_image)
    if pixels.ndim == 2:
        pixels = pixels[:, :, None]
    pil_tensor = torch.as_tensor(pixels.transpose((2, 0, 1)))
    failure_msg = msg if msg is not None else f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
    assert_equal(tensor.cpu(), pil_tensor, msg=failure_msg)
Nicolas Hug's avatar
Nicolas Hug committed
170
171


172
173
174
def _assert_approx_equal_tensor_to_pil(
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
Nicolas Hug's avatar
Nicolas Hug committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)

    if allowed_percentage_diff is not None:
        # Assert that less than a given %age of pixels are different
        assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff

    # error value can be mean absolute error, max abs error
    # Convert to float to avoid underflow when computing absolute difference
    tensor = tensor.to(torch.float)
    pil_tensor = pil_tensor.to(torch.float)
    err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
190
    assert err < tol, f"{err} vs {tol}"
Nicolas Hug's avatar
Nicolas Hug committed
191
192
193
194
195
196
197


def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
    transformed_batch = fn(batch_tensors, **fn_kwargs)
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        transformed_img = fn(img_tensor, **fn_kwargs)
198
        torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)
Nicolas Hug's avatar
Nicolas Hug committed
199
200
201
202
203
204

    if scripted_fn_atol >= 0:
        scripted_fn = torch.jit.script(fn)
        # scriptable function test
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
205
206
207


def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.9) or :func:`functools.lru_cache` with infinite cache size,
    but this also caches exceptions.

    Results (and raised exceptions) are keyed on the positional arguments together with the names *and*
    values of the keyword arguments, so every argument must be hashable. Keyword arguments are sorted by
    name, so their order at the call site does not matter.
    """
    sentinel = object()
    out_cache = {}
    exc_tb_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Include the keyword names: keying on the values alone would make e.g. ``fn(x, a=1)`` and
        # ``fn(x, b=1)`` collide and return each other's cached result.
        key = (args, tuple(sorted(kwargs.items())))

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc_tb = exc_tb_cache.get(key, sentinel)
        if exc_tb is not sentinel:
            raise exc_tb[0].with_traceback(exc_tb[1])

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            # We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
            # traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
            # the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
            exc_tb_cache[key] = exc, exc.__traceback__
            raise

        out_cache[key] = out
        return out

    return wrapper