common_utils.py 15.8 KB
Newer Older
limm's avatar
limm committed
1
2
3
import contextlib
import functools
import itertools
4
import os
limm's avatar
limm committed
5
6
7
import pathlib
import random
import re
8
import shutil
eellison's avatar
eellison committed
9
import sys
limm's avatar
limm committed
10
import tempfile
11
import warnings
limm's avatar
limm committed
12
from subprocess import CalledProcessError, check_output, STDOUT
13

14
import numpy as np
limm's avatar
limm committed
15
16
17
18
import PIL.Image
import pytest
import torch
import torch.testing
19
20
from PIL import Image

limm's avatar
limm committed
21
22
23
24
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
25

limm's avatar
limm committed
26
27

# Environment detection flags, evaluated once when the module is imported.
# True when CIRCLECI or GITHUB_ACTIONS is set to exactly "true".
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
# True when INSIDE_RE_WORKER is set to any value, including an empty string.
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
# True only when IN_FBCODE_TORCHVISION is exactly "1".
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"

# Human-readable messages. NOTE(review): presumably used as skip reasons by the test configuration -- confirm.
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
33

34
35
36
37
38
39
40
41
42
43
44

@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    """Context manager that yields a temporary directory and deletes it on exit.

    Args:
        src (path-like, optional): If given, the directory tree at ``src`` is copied into the temporary directory
            before it is yielded.
        **kwargs: Forwarded to :func:`tempfile.mkdtemp`.

    Yields:
        str: Path of the temporary directory.
    """
    tmp_dir = tempfile.mkdtemp(**kwargs)
    try:
        if src is not None:
            # Copy into the directory created by mkdtemp instead of removing and re-creating it. Doing this inside
            # the ``try`` guarantees that a failed or partial copy is cleaned up as well.
            shutil.copytree(src, tmp_dir, dirs_exist_ok=True)
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
eellison's avatar
eellison committed
45
46


47
48
49
50
def set_rng_seed(seed):
    """Seed PyTorch's global generator and Python's :mod:`random` module with ``seed``."""
    for seed_fn in (torch.manual_seed, random.seed):
        seed_fn(seed)

51

limm's avatar
limm committed
52
class MapNestedTensorObjectImpl:
    """Callable that applies a function to every tensor inside a nested structure.

    Dictionaries (keys and values), lists and tuples are traversed recursively. Lists come back as lists and tuples
    as plain tuples. Any other non-tensor object is returned unchanged.

    Args:
        tensor_map_fn (callable): Called with each :class:`torch.Tensor` that is found. Its return value takes the
            place of the tensor in the result.
    """

    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        # NOTE: the parameter keeps its builtin-shadowing name so that existing keyword callers continue to work.
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        if isinstance(object, dict):
            mapped_dict = {}
            for key, value in object.items():
                # The value is mapped before the key, matching the evaluation order of the assignment statement.
                mapped_dict[self(key)] = self(value)
            return mapped_dict

        if isinstance(object, (list, tuple)):
            mapped_items = [self(item) for item in object]
            return tuple(mapped_items) if isinstance(object, tuple) else mapped_items

        return object


def map_nested_tensor_object(object, tensor_map_fn):
    """Apply ``tensor_map_fn`` to ``object`` through a freshly built :class:`MapNestedTensorObjectImpl`."""
    return MapNestedTensorObjectImpl(tensor_map_fn)(object)


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def is_iterable(obj):
    """Return ``True`` if ``iter(obj)`` succeeds and ``False`` if it raises :class:`TypeError`."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True


@contextlib.contextmanager
def freeze_rng_state():
    """Context manager that restores PyTorch's RNG state on exit.

    The CPU generator state is always saved and restored. The CUDA generator state is handled as well when
    ``torch.cuda.is_available()`` is true on entry. The state is restored even if the managed block raises.
    """
    rng_state = torch.get_rng_state()
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Without try/finally an exception thrown into the generator at ``yield`` would skip the restore.
        if cuda_available:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
98
99


100
def cycle_over(objs):
    """Yield every ordered pair of elements of ``objs`` that sit at different positions.

    For each element, in order, it is paired with all other elements in their original order.

    Args:
        objs (iterable): Elements to pair up. Any iterable is accepted; it is materialised into a list first.

    Yields:
        tuple: ``(obj1, obj2)`` where ``obj1`` and ``obj2`` come from different positions of ``objs``.
    """
    # Slicing and concatenation only work for list-like sequences, so normalise ranges, generators, etc.
    objs = list(objs)
    for idx, obj1 in enumerate(objs):
        for obj2 in objs[:idx] + objs[idx + 1 :]:
            yield obj1, obj2
104
105
106


def int_dtypes():
    """Return the tuple of integer dtypes: ``uint8``, ``int8``, ``int16``, ``int32`` and ``int64``."""
    return (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
108
109
110


def float_dtypes():
    """Return the tuple of floating point dtypes: ``float32`` and ``float64``."""
    return (torch.float32, torch.float64)
112
113
114
115
116
117
118
119


@contextlib.contextmanager
def disable_console_output():
    """Context manager that redirects ``sys.stdout`` and ``sys.stderr`` to ``os.devnull`` while it is active."""
    with contextlib.ExitStack() as redirects:
        with open(os.devnull, "w") as sink:
            for redirect in (contextlib.redirect_stdout, contextlib.redirect_stderr):
                redirects.enter_context(redirect(sink))
            yield
120
121


limm's avatar
limm committed
122
123
124
125
def cpu_and_cuda():
    """Return the device parameters ``"cpu"`` and ``"cuda"``, the latter carrying the ``needs_cuda`` marker."""
    import pytest  # noqa

    cuda_param = pytest.param("cuda", marks=pytest.mark.needs_cuda)
    return ("cpu", cuda_param)
126

127

limm's avatar
limm committed
128
129
def cpu_and_cuda_and_mps():
    """Return the parameters of :func:`cpu_and_cuda` followed by ``"mps"`` carrying the ``needs_mps`` marker."""
    devices = cpu_and_cuda()
    mps_param = pytest.param("mps", marks=pytest.mark.needs_mps)
    return devices + (mps_param,)
130

limm's avatar
limm committed
131
132

def needs_cuda(test_func):
    """Decorator that applies the ``needs_cuda`` pytest marker to ``test_func``.

    NOTE(review): how the marker is acted upon (e.g. skipping without a CUDA device) is decided outside this
    function -- presumably in the pytest configuration; confirm there.
    """
    import pytest  # noqa

    return pytest.mark.needs_cuda(test_func)
136
137


limm's avatar
limm committed
138
139
def needs_mps(test_func):
    """Decorator that applies the ``needs_mps`` pytest marker to ``test_func``.

    NOTE(review): how the marker is acted upon (e.g. skipping without an MPS device) is decided outside this
    function -- presumably in the pytest configuration; confirm there.
    """
    import pytest  # noqa

    return pytest.mark.needs_mps(test_func)
142
143


limm's avatar
limm committed
144
145
146
147
148
149
150
151
152
153
def _create_data(height=3, width=3, channels=3, device="cpu"):
    """Create a random ``uint8`` tensor of shape ``(channels, height, width)`` and a PIL image built from it.

    The PIL image uses mode ``"L"`` when ``channels == 1`` and mode ``"RGB"`` otherwise.

    Returns:
        tuple: ``(tensor, pil_img)``; the tensor lives on ``device``.
    """
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    # Channels-last copy on the CPU as a NumPy array for PIL.
    hwc = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
    if channels == 1:
        pil_img = Image.fromarray(hwc[..., 0], mode="L")
    else:
        pil_img = Image.fromarray(hwc, mode="RGB")
    return tensor, pil_img
154

155

limm's avatar
limm committed
156
157
158
159
def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
    """Create a random ``uint8`` tensor of shape ``(num_samples, channels, height, width)`` on ``device``."""
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    shape = (num_samples, channels, height, width)
    return torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
160
161


limm's avatar
limm committed
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    """Write ``num_videos`` random videos into ``tmpdir`` and return their paths.

    Video ``i`` is written to ``<tmpdir>/<i>.mp4`` from a random ``uint8`` tensor of shape ``(size, 300, 400, 3)``.

    Args:
        tmpdir: Directory the files are written to.
        num_videos (int): Number of videos to write.
        sizes (sequence, optional): ``sizes[i]`` is the first dimension of video ``i``. Defaults to ``5 * (i + 1)``.
        fps (sequence, optional): ``fps[i]`` is passed as ``fps`` for video ``i``. Defaults to ``5``.

    Returns:
        list: The file paths in creation order.
    """
    names = []
    for i in range(num_videos):
        size = 5 * (i + 1) if sizes is None else sizes[i]
        f = 5 if fps is None else fps[i]
        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
        name = os.path.join(tmpdir, f"{i}.mp4")
        names.append(name)
        io.write_video(name, data, fps=f)

    return names


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    """Assert via ``assert_equal`` that ``tensor`` (moved to the CPU) equals ``pil_image`` converted to a tensor.

    ``pil_image`` is converted through ``np.array``; a 2D result gets a trailing axis of length one, and the array
    is then transposed with ``(2, 0, 1)`` before the comparison.
    """
    # FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
    pil_array = np.array(pil_image)
    if pil_array.ndim == 2:
        pil_array = pil_array[..., None]
    pil_tensor = torch.as_tensor(np.transpose(pil_array, (2, 0, 1)))
    if msg is None:
        msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
    assert_equal(tensor.cpu(), pil_tensor, msg=msg)


def _assert_approx_equal_tensor_to_pil(
    tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
    """Assert that ``tensor`` and ``pil_image`` differ by less than ``tol`` under an aggregated absolute error.

    Args:
        tensor: Tensor to check.
        pil_image: Image converted through ``np.array``, given a trailing axis if 2D, transposed with ``(2, 0, 1)``
            and moved to the dtype and device of ``tensor``.
        tol (float): Exclusive upper bound for the aggregated absolute error.
        msg: Not used by this function.
        agg_method (str): Name of the ``torch`` function that aggregates the absolute differences, e.g. ``"mean"``
            or ``"max"``.
        allowed_percentage_diff (float, optional): If given, the fraction of differing elements must not exceed it.
    """
    # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
    pil_array = np.array(pil_image)
    if pil_array.ndim == 2:
        pil_array = pil_array[..., None]
    pil_tensor = torch.as_tensor(np.transpose(pil_array, (2, 0, 1))).to(tensor)

    if allowed_percentage_diff is not None:
        # Assert that less than a given %age of pixels are different
        mismatch_fraction = (tensor != pil_tensor).to(torch.float).mean()
        assert mismatch_fraction <= allowed_percentage_diff

    # error value can be mean absolute error, max abs error
    # Convert to float to avoid underflow when computing absolute difference
    tensor_f = tensor.to(torch.float)
    pil_tensor_f = pil_tensor.to(torch.float)
    aggregate = getattr(torch, agg_method)
    err = aggregate(torch.abs(tensor_f - pil_tensor_f)).item()
    assert err < tol, f"{err} vs {tol}"


def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
    """Check that ``fn`` gives the same result per sample as on the whole batch, optionally also when scripted.

    Args:
        batch_tensors: Batched input; indexed along its first dimension.
        fn: Function under test, called as ``fn(input, **fn_kwargs)``.
        scripted_fn_atol (float): Absolute tolerance for comparing eager and ``torch.jit.script`` batch results.
            The scripted comparison only runs when this is ``>= 0``.
        **fn_kwargs: Extra keyword arguments for ``fn``.
    """
    transformed_batch = fn(batch_tensors, **fn_kwargs)
    for idx, img_tensor in enumerate(batch_tensors):
        transformed_img = fn(img_tensor, **fn_kwargs)
        torch.testing.assert_close(transformed_img, transformed_batch[idx, ...], rtol=0, atol=1e-6)

    if scripted_fn_atol >= 0:
        # scriptable function test
        scripted_fn = torch.jit.script(fn)
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)


def cache(fn):
    """Similar to :func:`functools.cache` (Python >= 3.9) or :func:`functools.lru_cache` with infinite cache size,
    but this also caches exceptions.

    The cache key is built from the positional arguments and the keyword arguments *including their names*, so
    ``fn(a=1)`` and ``fn(b=1)`` are cached separately. The order in which keyword arguments are passed does not
    matter. All arguments need to be hashable.

    Args:
        fn (callable): Function to wrap.

    Returns:
        callable: Wrapper that returns the cached result, or re-raises the cached exception with its original
        traceback, for arguments it has seen before.
    """
    sentinel = object()
    out_cache = {}
    exc_tb_cache = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Keyword names are unique, so sorting the items never has to compare the (possibly unorderable) values.
        key = (args, tuple(sorted(kwargs.items())))

        out = out_cache.get(key, sentinel)
        if out is not sentinel:
            return out

        exc_tb = exc_tb_cache.get(key, sentinel)
        if exc_tb is not sentinel:
            raise exc_tb[0].with_traceback(exc_tb[1])

        try:
            out = fn(*args, **kwargs)
        except Exception as exc:
            # We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
            # traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
            # the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
            exc_tb_cache[key] = exc, exc.__traceback__
            raise exc

        out_cache[key] = out
        return out

    return wrapper


def combinations_grid(**kwargs):
    """Build the full grid of keyword combinations.

    Every keyword maps to an iterable of candidate values. The result has one dictionary per element of the
    cartesian product of these iterables, in the order produced by :func:`itertools.product`.

    Example:
        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
        [
            {'foo': 'bar', 'spam': 'eggs'},
            {'foo': 'bar', 'spam': 'ham'},
            {'foo': 'baz', 'spam': 'eggs'},
            {'foo': 'baz', 'spam': 'ham'}
        ]
    """
    names = tuple(kwargs)
    grid = []
    for combination in itertools.product(*kwargs.values()):
        grid.append(dict(zip(names, combination)))
    return grid


class ImagePair(TensorLikePair):
    """Comparison pair that also accepts PIL images and can compare by mean absolute error (MAE).

    If both ``actual`` and ``expected`` are PIL images, they are converted with ``to_image`` before being passed on
    to :class:`TensorLikePair`.

    Args:
        actual: Actual input.
        expected: Expected input.
        mae (bool): If true, the values are compared through their MAE, which must not exceed ``atol``. Otherwise
            the value comparison of the parent class is used.
        **other_parameters: Forwarded to :class:`TensorLikePair`.
    """

    def __init__(
        self,
        actual,
        expected,
        *,
        mae=False,
        **other_parameters,
    ):
        both_pil = isinstance(actual, PIL.Image.Image) and isinstance(expected, PIL.Image.Image)
        if both_pil:
            actual = to_image(actual)
            expected = to_image(expected)

        super().__init__(actual, expected, **other_parameters)
        self.mae = mae

    def compare(self) -> None:
        actual, expected = self.actual, self.expected

        self._compare_attributes(actual, expected)
        actual, expected = self._equalize_attributes(actual, expected)

        if not self.mae:
            super()._compare_values(actual, expected)
            return

        if actual.dtype is torch.uint8:
            # Switch to a signed integer type before subtracting, since uint8 cannot represent negative differences.
            actual = actual.to(torch.int)
            expected = expected.to(torch.int)
        mae = float(torch.abs(actual - expected).float().mean())
        if mae > self.atol:
            self._fail(
                AssertionError,
                f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
            )


def assert_close(
    actual,
    expected,
    *,
    allow_subclasses=True,
    rtol=None,
    atol=None,
    equal_nan=False,
    check_device=True,
    check_dtype=True,
    check_layout=True,
    check_stride=False,
    msg=None,
    **kwargs,
):
    """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison

    All comparison options as well as ``**kwargs`` are handed to ``not_close_error_metas`` together with the pair
    types ``NonePair``, ``BooleanPair``, ``NumberPair``, ``ImagePair`` and ``TensorLikePair``.

    Raises:
        The error built by ``to_error(msg)`` of the first reported mismatch, if there is any.
    """
    __tracebackhide__ = True

    pair_types = (
        NonePair,
        BooleanPair,
        NumberPair,
        ImagePair,
        TensorLikePair,
    )
    options = dict(
        allow_subclasses=allow_subclasses,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=check_device,
        check_dtype=check_dtype,
        check_layout=check_layout,
        check_stride=check_stride,
    )
    error_metas = not_close_error_metas(actual, expected, pair_types=pair_types, **options, **kwargs)

    if error_metas:
        raise error_metas[0].to_error(msg)


# Variant of `assert_close` with both tolerances pinned to zero.
assert_equal = functools.partial(assert_close, rtol=0, atol=0)


# Default for the `size` / `canvas_size` parameter of the `make_*` helpers in this file.
DEFAULT_SIZE = (17, 11)


# Number of channels for each color space name accepted by `make_image`.
NUM_CHANNELS_MAP = {
    "GRAY": 1,
    "GRAY_ALPHA": 2,
    "RGB": 3,
    "RGBA": 4,
}


def make_image(
    size=DEFAULT_SIZE,
    *,
    color_space="RGB",
    batch_dims=(),
    dtype=None,
    device="cpu",
    memory_format=torch.contiguous_format,
):
    """Create a random ``tv_tensors.Image`` of shape ``(*batch_dims, num_channels, *size)``.

    Args:
        size (tuple): Trailing dimensions of the image.
        color_space (str): Key of ``NUM_CHANNELS_MAP`` that selects the number of channels. For ``"GRAY_ALPHA"``
            and ``"RGBA"`` the last channel is filled with the maximum value of the dtype.
        batch_dims (tuple): Leading dimensions.
        dtype (torch.dtype, optional): Defaults to ``torch.uint8``.
        device: Device of the image.
        memory_format: Passed to :func:`torch.testing.make_tensor`.
    """
    num_channels = NUM_CHANNELS_MAP[color_space]
    dtype = dtype or torch.uint8
    max_value = get_max_value(dtype)
    shape = (*batch_dims, num_channels, *size)
    data = torch.testing.make_tensor(
        shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
    )
    has_extra_last_channel = color_space in ("GRAY_ALPHA", "RGBA")
    if has_extra_last_channel:
        data[..., -1, :, :] = max_value

    return tv_tensors.Image(data)


def make_image_tensor(*args, **kwargs):
    """Same as :func:`make_image`, but the result is viewed as a plain :class:`torch.Tensor`."""
    image = make_image(*args, **kwargs)
    return image.as_subclass(torch.Tensor)


def make_image_pil(*args, **kwargs):
    """Same as :func:`make_image`, but the result is converted with ``to_pil_image``."""
    image = make_image(*args, **kwargs)
    return to_pil_image(image)


def make_bounding_boxes(
    canvas_size=DEFAULT_SIZE,
    *,
    format=tv_tensors.BoundingBoxFormat.XYXY,
    num_boxes=1,
    dtype=None,
    device="cpu",
):
    """Create ``num_boxes`` random bounding boxes for a canvas of size ``canvas_size``.

    Args:
        canvas_size (tuple): ``(height, width)`` of the canvas.
        format (tv_tensors.BoundingBoxFormat or str): Box format. A string is looked up by name in
            ``tv_tensors.BoundingBoxFormat``. ``XYWH``, ``XYXY`` and ``CXCYWH`` are supported.
        num_boxes (int): Number of boxes.
        dtype (torch.dtype, optional): Defaults to ``torch.float32``.
        device: Device of the boxes.

    Returns:
        tv_tensors.BoundingBoxes: Built from the four box components stacked along the last dimension.

    Raises:
        ValueError: If ``format`` is not one of the supported formats.
    """

    def sample_position(values, max_value):
        # Every box has its own upper bound ``max_value - v``, but torch.randint only accepts integer scalars for
        # its bounds. Thus, the positions are sampled one at a time.
        return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])

    if isinstance(format, str):
        format = tv_tensors.BoundingBoxFormat[format]

    dtype = dtype or torch.float32

    # Heights and widths are drawn from [1, s) and the offsets from [0, s - extent).
    h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
    y = sample_position(h, canvas_size[0])
    x = sample_position(w, canvas_size[1])

    if format is tv_tensors.BoundingBoxFormat.XYWH:
        parts = (x, y, w, h)
    elif format is tv_tensors.BoundingBoxFormat.XYXY:
        x1, y1 = x, y
        x2 = x1 + w
        y2 = y1 + h
        parts = (x1, y1, x2, y2)
    elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
        cx = x + w / 2
        cy = y + h / 2
        parts = (cx, cy, w, h)
    else:
        raise ValueError(f"Format {format} is not supported")

    return tv_tensors.BoundingBoxes(
        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
    )


def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
    """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
    mask_dtype = dtype or torch.bool
    shape = (num_masks, *size)
    data = torch.testing.make_tensor(shape, low=0, high=2, dtype=mask_dtype, device=device)
    return tv_tensors.Mask(data)


def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
    """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
    mask_dtype = dtype or torch.uint8
    shape = (*batch_dims, *size)
    data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=mask_dtype, device=device)
    return tv_tensors.Mask(data)


def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
    """Create a random ``tv_tensors.Video`` by calling :func:`make_image` with ``num_frames`` appended to
    ``batch_dims``. Additional keyword arguments are forwarded to :func:`make_image`.
    """
    frames = make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs)
    return tv_tensors.Video(frames)


def make_video_tensor(*args, **kwargs):
    """Same as :func:`make_video`, but the result is viewed as a plain :class:`torch.Tensor`."""
    video = make_video(*args, **kwargs)
    return video.as_subclass(torch.Tensor)


def assert_run_python_script(source_code):
    """Run ``source_code`` in a separate Python process and check that it succeeds silently.

    The code is written to ``main.py`` inside a temporary directory and executed with the current interpreter.
    stderr is merged into stdout. Modified from scikit-learn test utils.

    Args:
        source_code (str): The Python source code to execute.

    Raises:
        RuntimeError: If the process exits with a non-zero status. The message contains the captured output.
        AssertionError: If the process succeeds, but printed anything. The message is the captured output.
    """
    with get_tmp_dir() as root:
        script = pathlib.Path(root) / "main.py"
        script.write_text(source_code)

        command = [sys.executable, str(script)]
        try:
            out = check_output(command, stderr=STDOUT)
        except CalledProcessError as e:
            raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
        if out:
            raise AssertionError(out.decode())


@contextlib.contextmanager
def assert_no_warnings():
    """Context manager inside which every warning is turned into an exception."""
    # The name `catch_warnings` is a misnomer as the context manager does **not** catch any warnings, but rather scopes
    # the warning filters. All changes that are made to the filters while in this context, will be reset upon exit.
    scoped_filters = warnings.catch_warnings()
    with scoped_filters:
        warnings.simplefilter("error")
        yield


@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
    """Context manager that ignores ``UserWarning``s whose message starts with ``operator() profile_node %``."""
    # Calling a scripted object often triggers a warning like
    # `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
    # with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
    # them.
    message_pattern = re.escape("operator() profile_node %")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=message_pattern, category=UserWarning)
        yield