import os
import shutil
import tempfile
import contextlib
import unittest
import pytest
import argparse
import sys
import torch
import __main__
import random
import inspect
import functools

from numbers import Number
from torch._six import string_classes
from collections import OrderedDict

import numpy as np
from PIL import Image

IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = pytest.mark.skipif(IS_PY39, reason=PY39_SEGFAULT_SKIP_MSG)
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."


@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
    # Create a temporary directory (optionally copied from `src`) and remove it on exit
    tmp_dir = tempfile.mkdtemp(**kwargs)
    if src is not None:
        os.rmdir(tmp_dir)
        shutil.copytree(src, tmp_dir)
    try:
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)


def set_rng_seed(seed):
    # Seed all random number generators used by the tests
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


class MapNestedTensorObjectImpl(object):
    def __init__(self, tensor_map_fn):
        self.tensor_map_fn = tensor_map_fn

    def __call__(self, object):
        if isinstance(object, torch.Tensor):
            return self.tensor_map_fn(object)

        elif isinstance(object, dict):
            mapped_dict = {}
            for key, value in object.items():
                mapped_dict[self(key)] = self(value)
            return mapped_dict

        elif isinstance(object, (list, tuple)):
            mapped_iter = []
            for iter in object:
                mapped_iter.append(self(iter))
            return mapped_iter if not isinstance(object, tuple) else tuple(mapped_iter)

        else:
            return object


def map_nested_tensor_object(object, tensor_map_fn):
    # Apply `tensor_map_fn` to every tensor in a (possibly nested) container
    impl = MapNestedTensorObjectImpl(tensor_map_fn)
    return impl(object)


def is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False


@contextlib.contextmanager
def freeze_rng_state():
    # Save the CPU (and CUDA, if available) RNG state and restore it on exit
    rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()
    yield
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
    torch.set_rng_state(rng_state)


def cycle_over(objs):
    # Yield every ordered pair of distinct elements of `objs`
    for idx, obj1 in enumerate(objs):
        for obj2 in objs[:idx] + objs[idx + 1:]:
            yield obj1, obj2


def int_dtypes():
    return torch.testing.integral_types()


def float_dtypes():
    return torch.testing.floating_types()


@contextlib.contextmanager
def disable_console_output():
    with contextlib.ExitStack() as stack, open(os.devnull, "w") as devnull:
        stack.enter_context(contextlib.redirect_stdout(devnull))
        stack.enter_context(contextlib.redirect_stderr(devnull))
        yield


def cpu_and_gpu():
    import pytest  # noqa

    return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda))


def needs_cuda(test_func):
    import pytest  # noqa

    return pytest.mark.needs_cuda(test_func)


def _create_data(height=3, width=3, channels=3, device="cpu"):
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
    pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
    return tensor, pil_img


def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
    # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
    batch_tensor = torch.randint(
        0, 256,
        (num_samples, channels, height, width),
        dtype=torch.uint8,
        device=device
    )
    return batch_tensor


assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
    if msg is None:
        msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
    assert_equal(tensor.cpu(), pil_tensor, msg=msg)


def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
                                       allowed_percentage_diff=None):
    # TODO: we could just merge this into _assert_equal_tensor_to_pil
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)

    if allowed_percentage_diff is not None:
        # Assert that less than a given percentage of pixels are different
        assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff

    # The error value can be the mean absolute error or the max absolute error,
    # depending on `agg_method`.
    # Convert to float to avoid underflow when computing the absolute difference.
    tensor = tensor.to(torch.float)
    pil_tensor = pil_tensor.to(torch.float)
    err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()

    assert err < tol


def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
    # Check that applying `fn` to a batch matches applying it to each sample individually
    transformed_batch = fn(batch_tensors, **fn_kwargs)
    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        transformed_img = fn(img_tensor, **fn_kwargs)
        assert_equal(transformed_img, transformed_batch[i, ...])

    if scripted_fn_atol >= 0:
        # scriptable function test: the scripted version must match the eager one
        scripted_fn = torch.jit.script(fn)
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
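

# Illustrative sketch (kept as a comment so it has no effect on this module):
# a hypothetical pytest test showing how the helpers above typically compose.
# The transform (torchvision.transforms.functional.hflip) and the test name are
# example choices, not something these utilities require.
#
# import pytest
# import torchvision.transforms.functional as F
#
# @pytest.mark.parametrize('device', cpu_and_gpu())
# def test_hflip_example(device):
#     set_rng_seed(0)
#     tensor, pil_img = _create_data(height=8, width=8, device=device)
#     # Tensor and PIL results should agree exactly for a pure flip
#     _assert_equal_tensor_to_pil(F.hflip(tensor), F.hflip(pil_img))
#     # Batched input should behave like per-sample application (and script cleanly)
#     batch = _create_data_batch(height=8, width=8, num_samples=4, device=device)
#     _test_fn_on_batch(batch, F.hflip)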