import contextlib
import functools
import itertools
import os
import pathlib
import random
import re
import shutil
import sys
import tempfile
import warnings
from subprocess import CalledProcessError, check_output, STDOUT

import numpy as np
import PIL.Image
import pytest
import torch
import torch.testing
from PIL import Image

from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."


@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    tmp_dir = tempfile.mkdtemp(**kwargs)
    if src is not None:
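        # shutil.copytree requires that the destination does not exist yet, so remove the freshly created dir first.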
        os.rmdir(tmp_dir)
        shutil.copytree(src, tmp_dir)
    try:
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
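

# Usage sketch (illustrative, not part of the original module): the temporary
# directory, optionally pre-populated with a copy of `src`, only exists for the
# duration of the `with` block.
#
#     with get_tmp_dir() as tmp:
#         (pathlib.Path(tmp) / "file.txt").write_text("hello")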


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)


class MapNestedTensorObjectImpl:
    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, obj):
        if isinstance(obj, torch.Tensor):
            return self.tensor_map_fn(obj)

        elif isinstance(obj, dict):
            mapped_dict = {}
            for key, value in obj.items():
                mapped_dict[self(key)] = self(value)
            return mapped_dict

        elif isinstance(obj, (list, tuple)):
            mapped_iter = []
            for item in obj:
                mapped_iter.append(self(item))
            return mapped_iter if not isinstance(obj, tuple) else tuple(mapped_iter)

        else:
            return obj


def map_nested_tensor_object(obj, tensor_map_fn):
    impl = MapNestedTensorObjectImpl(tensor_map_fn)
    return impl(obj)
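

# Usage sketch (illustrative; `outputs` stands for any nested dict/list/tuple of
# tensors): move every tensor to CPU while leaving non-tensor leaves untouched.
#
#     cpu_outputs = map_nested_tensor_object(outputs, tensor_map_fn=lambda t: t.cpu())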


def is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False


@contextlib.contextmanager
def freeze_rng_state():
    rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Restore the RNG state even if the wrapped block raises.
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
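

# Usage sketch (illustrative): two consecutive blocks see identical random draws,
# and the surrounding RNG stream is left untouched.
#
#     with freeze_rng_state():
#         a = torch.rand(3)
#     with freeze_rng_state():
#         b = torch.rand(3)
#     torch.testing.assert_close(a, b, rtol=0, atol=0)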


def cycle_over(objs):
    for idx, obj1 in enumerate(objs):
        for obj2 in objs[:idx] + objs[idx + 1 :]:
            yield obj1, obj2
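

# For example (illustrative), cycle_over([1, 2, 3]) yields the ordered pairs
# (1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2).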


def int_dtypes():
    return (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)


def float_dtypes():
    return (torch.float32, torch.float64)


@contextlib.contextmanager
def disable_console_output():
    with contextlib.ExitStack() as stack, open(os.devnull, "w") as devnull:
        stack.enter_context(contextlib.redirect_stdout(devnull))
        stack.enter_context(contextlib.redirect_stderr(devnull))
        yield


def cpu_and_cuda():
    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def cpu_and_cuda_and_mps():
    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)


def needs_cuda(test_func):
    return pytest.mark.needs_cuda(test_func)


def needs_mps(test_func):
    return pytest.mark.needs_mps(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
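    # PIL expects HWC-ordered (or 2D grayscale) data, so convert the CHW tensor accordingly.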
    data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    mode = "RGB"
    if channels == 1:
        mode = "L"
        data = data[..., 0]
    pil_img = Image.fromarray(data, mode=mode)
    return tensor, pil_img


def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device)
    return batch_tensor


def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    names = []
    for i in range(num_videos):
        if sizes is None:
            size = 5 * (i + 1)
        else:
            size = sizes[i]
        if fps is None:
            f = 5
        else:
            f = fps[i]
        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
        name = os.path.join(tmpdir, f"{i}.mp4")
        names.append(name)
        io.write_video(name, data, fps=f)

    return names


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    # FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
    if msg is None:
        msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
    assert_equal(tensor.cpu(), pil_tensor, msg=msg)


def _assert_approx_equal_tensor_to_pil(
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
    # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)

    if allowed_percentage_diff is not None:
        # Assert that no more than the given percentage of pixels differ
        assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff

    # The error value is e.g. the mean or the max absolute error, depending on `agg_method`
    # Convert to float to avoid underflow when computing absolute difference
    tensor = tensor.to(torch.float)
    pil_tensor = pil_tensor.to(torch.float)
    err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
    assert err < tol, f"{err} vs {tol}"


def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
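    # Check that applying `fn` to each image individually matches applying it to the whole batch at once.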
    transformed_batch = fn(batch_tensors, **fn_kwargs)
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        transformed_img = fn(img_tensor, **fn_kwargs)
        torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)

    if scripted_fn_atol >= 0:
        scripted_fn = torch.jit.script(fn)
        # scriptable function test
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)


def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
    but this also caches exceptions.
231
232
233
    """
    sentinel = object()
    out_cache = {}
    exc_tb_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        key = args + tuple(kwargs.values())

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc_tb = exc_tb_cache.get(key, sentinel)
        if exc_tb is not sentinel:
            raise exc_tb[0].with_traceback(exc_tb[1])

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            # We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
            # traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
            # the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
            exc_tb_cache[key] = exc, exc.__traceback__
            raise exc

        out_cache[key] = out
        return out

    return wrapper
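

# Usage sketch (illustrative; `make_fixture` is a hypothetical expensive helper):
# the wrapped function runs at most once per argument tuple, and a raised
# exception is re-raised from the cache on every subsequent call instead of
# being recomputed.
#
#     @cache
#     def make_fixture(name):
#         ...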


def combinations_grid(**kwargs):
    """Creates a grid of input combinations.

    Each element in the returned sequence is a dictionary containing one possible combination as values.

    Example:
        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
        [
            {'foo': 'bar', 'spam': 'eggs'},
            {'foo': 'bar', 'spam': 'ham'},
            {'foo': 'baz', 'spam': 'eggs'},
            {'foo': 'baz', 'spam': 'ham'}
        ]
    """
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


class ImagePair(TensorLikePair):
    def __init__(
        self,
        actual,
        expected,
        *,
        mae=False,
        **other_parameters,
    ):
        if all(isinstance(inpt, PIL.Image.Image) for inpt in [actual, expected]):
            actual, expected = [to_image(inpt) for inpt in [actual, expected]]

        super().__init__(actual, expected, **other_parameters)
        self.mae = mae

    def compare(self) -> None:
        actual, expected = self.actual, self.expected

        self._compare_attributes(actual, expected)
        actual, expected = self._equalize_attributes(actual, expected)

        if self.mae:
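            # Upcast uint8 images first so the subtraction below cannot wrap around.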
            if actual.dtype is torch.uint8:
                actual, expected = actual.to(torch.int), expected.to(torch.int)
            mae = float(torch.abs(actual - expected).float().mean())
            if mae > self.atol:
                self._fail(
                    AssertionError,
                    f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
                )
        else:
            super()._compare_values(actual, expected)


def assert_close(
    actual,
    expected,
    *,
    allow_subclasses=True,
    rtol=None,
    atol=None,
    equal_nan=False,
    check_device=True,
    check_dtype=True,
    check_layout=True,
    check_stride=False,
    msg=None,
    **kwargs,
):
    """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
    __tracebackhide__ = True

    error_metas = not_close_error_metas(
        actual,
        expected,
        pair_types=(
            NonePair,
            BooleanPair,
            NumberPair,
            ImagePair,
            TensorLikePair,
        ),
        allow_subclasses=allow_subclasses,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        check_stride=check_stride,
        **kwargs,
    )

    if error_metas:
        raise error_metas[0].to_error(msg)


assert_equal = functools.partial(assert_close, rtol=0, atol=0)
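

# Usage sketch (illustrative): `assert_equal` enforces exact equality, while
# passing `mae=True` (consumed by `ImagePair` above) compares images by their
# mean absolute error against `atol` instead of elementwise closeness.
#
#     img = make_image(dtype=torch.float32)
#     assert_equal(img, img.clone())
#     assert_close(img + 0.5, img, mae=True, atol=1.0)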


DEFAULT_SIZE = (17, 11)


NUM_CHANNELS_MAP = {
    "GRAY": 1,
    "GRAY_ALPHA": 2,
    "RGB": 3,
    "RGBA": 4,
}


def make_image(
    size=DEFAULT_SIZE,
    *,
    color_space="RGB",
    batch_dims=(),
    dtype=None,
    device="cpu",
    memory_format=torch.contiguous_format,
):
    num_channels = NUM_CHANNELS_MAP[color_space]
    dtype = dtype or torch.uint8
    max_value = get_max_value(dtype)
    data = torch.testing.make_tensor(
        (*batch_dims, num_channels, *size),
        low=0,
        high=max_value,
        dtype=dtype,
        device=device,
        memory_format=memory_format,
    )
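    # Make the alpha channel fully opaque for color spaces that have one.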
    if color_space in {"GRAY_ALPHA", "RGBA"}:
        data[..., -1, :, :] = max_value

    return tv_tensors.Image(data)
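

# Usage sketch (illustrative): with the defaults above, this returns a batched
# RGBA image of shape (2, 4, 17, 11), i.e. (*batch_dims, C, H, W), with an
# opaque alpha channel.
#
#     img = make_image(color_space="RGBA", batch_dims=(2,))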


def make_image_tensor(*args, **kwargs):
    return make_image(*args, **kwargs).as_subclass(torch.Tensor)


def make_image_pil(*args, **kwargs):
    return to_pil_image(make_image(*args, **kwargs))


def make_bounding_boxes(
    canvas_size=DEFAULT_SIZE,
    *,
    format=tv_tensors.BoundingBoxFormat.XYXY,
    dtype=None,
    device="cpu",
):
    def sample_position(values, max_value):
        # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
        # Here the `high` limit differs per object, so we sample one position per object and stack the results.
        return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])

    if isinstance(format, str):
        format = tv_tensors.BoundingBoxFormat[format]

    dtype = dtype or torch.float32

    num_objects = 1
    h, w = [torch.randint(1, c, (num_objects,)) for c in canvas_size]
    y = sample_position(h, canvas_size[0])
    x = sample_position(w, canvas_size[1])

    if format is tv_tensors.BoundingBoxFormat.XYWH:
        parts = (x, y, w, h)
    elif format is tv_tensors.BoundingBoxFormat.XYXY:
        x1, y1 = x, y
        x2 = x1 + w
        y2 = y1 + h
        parts = (x1, y1, x2, y2)
    elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
        cx = x + w / 2
        cy = y + h / 2
        parts = (cx, cy, w, h)
    else:
        raise ValueError(f"Format {format} is not supported")

    return tv_tensors.BoundingBoxes(
        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
    )
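

# Usage sketch (illustrative): `num_objects` is fixed to 1 above, so each call
# returns a BoundingBoxes tensor of shape (1, 4) in the requested format.
#
#     boxes = make_bounding_boxes(format="XYWH", dtype=torch.int64)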


def make_detection_mask(size=DEFAULT_SIZE, *, dtype=None, device="cpu"):
    """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
    num_objects = 1
    return tv_tensors.Mask(
        torch.testing.make_tensor(
            (num_objects, *size),
            low=0,
            high=2,
            dtype=dtype or torch.bool,
            device=device,
        )
    )


def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
    """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
    return tv_tensors.Mask(
        torch.testing.make_tensor(
            (*batch_dims, *size),
            low=0,
            high=num_categories,
            dtype=dtype or torch.uint8,
            device=device,
        )
    )


def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
    return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
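

# Usage sketch (illustrative): frames are stacked along an extra leading
# dimension, so the default call returns a Video of shape
# (num_frames, C, H, W) == (3, 3, 17, 11).
#
#     video = make_video()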


def make_video_tensor(*args, **kwargs):
    return make_video(*args, **kwargs).as_subclass(torch.Tensor)


def assert_run_python_script(source_code):
    """Utility to check assertions in an independent Python subprocess.

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout. Modified from scikit-learn test utils.

    Args:
        source_code (str): The Python source code to execute.
    """
    with get_tmp_dir() as root:
        path = pathlib.Path(root) / "main.py"
        with open(path, "w") as file:
            file.write(source_code)

        try:
            out = check_output([sys.executable, str(path)], stderr=STDOUT)
        except CalledProcessError as e:
            raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
        if out != b"":
            raise AssertionError(out.decode())
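

# Usage sketch (illustrative): fails if the subprocess exits non-zero or prints
# anything to stdout or stderr.
#
#     assert_run_python_script("import torchvision")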


@contextlib.contextmanager
def assert_no_warnings():
    # The name `catch_warnings` is a misnomer as the context manager does **not** catch any warnings, but rather scopes
    # the warning filters. All changes that are made to the filters while in this context, will be reset upon exit.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        yield


@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
    # Calling a scripted object often triggers a warning like
    # `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
    # with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
    # them.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
        yield