Commit bf491463 authored by limm

add v0.19.1 release

parent e17f5ea2
import contextlib
import functools
import itertools
import os
import pathlib
import random
import re
import shutil
import tempfile
import contextlib
import unittest
import argparse
import sys
import io
import torch
import warnings
import __main__
from subprocess import CalledProcessError, check_output, STDOUT
import random
import inspect
from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
from _utils_internal import get_relative_path
import numpy as np
import PIL.Image
import pytest
import torch
import torch.testing
from PIL import Image
from _assert_utils import assert_equal
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_image, to_pil_image
IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@contextlib.contextmanager
...@@ -46,14 +47,9 @@ def get_tmp_dir(src=None, **kwargs):
def set_rng_seed(seed):
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'
class MapNestedTensorObjectImpl:
def __init__(self, tensor_map_fn):
self.tensor_map_fn = tensor_map_fn
...@@ -90,230 +86,6 @@ def is_iterable(obj):
return False
# adapted from TestCase in torch/test/common_utils to accept non-string
# inputs and set maximum binary size
class TestCase(unittest.TestCase):
precision = 1e-5
def _get_expected_file(self, name=None):
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
# Determine expected file based on environment
expected_file_base = get_relative_path(
os.path.realpath(sys.modules[module_id].__file__),
"expect")
# Note: for legacy reasons, the reference file names all had "ModelTester.test_" in their names
# We hardcode it here to avoid having to re-generate the reference files
expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name)
expected_file += "_expect.pkl"
if not ACCEPT and not os.path.exists(expected_file):
raise RuntimeError(
f"No expect file exists for {os.path.basename(expected_file)} in {expected_file}; "
"to accept the current output, re-run the failing test after setting the EXPECTTEST_ACCEPT "
"env variable. For example: EXPECTTEST_ACCEPT=1 pytest test/test_models.py -k alexnet"
)
return expected_file
def assertExpected(self, output, name, prec=None):
r"""
Test that a python value matches the recorded contents of a file
based on a "check" name. The value must be
picklable with `torch.save`. This file
is placed in the 'expect' directory in the same directory
as the test script. You can automatically update the recorded test
output using an EXPECTTEST_ACCEPT=1 env variable.
"""
expected_file = self._get_expected_file(name)
if ACCEPT:
filename = os.path.basename(expected_file)
print("Accepting updated output for {}:\n\n{}".format(filename, output))
torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError("The output for {}, is larger than 50kb".format(filename))
else:
expected = torch.load(expected_file)
rtol = atol = prec or self.precision
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
"""
This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
"""
if isinstance(prec, str) and message == '':
message = prec
prec = None
if prec is None:
prec = self.precision
if isinstance(x, torch.Tensor) and isinstance(y, Number):
self.assertEqual(x.item(), y, prec=prec, message=message,
allow_inf=allow_inf)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
self.assertEqual(x, y.item(), prec=prec, message=message,
allow_inf=allow_inf)
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
def assertTensorsEqual(a, b):
super(TestCase, self).assertEqual(a.size(), b.size(), message)
if a.numel() > 0:
if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
# CPU half and bfloat16 tensors don't have the methods we need below
a = a.to(torch.float32)
b = b.to(a)
if (a.dtype == torch.bool) != (b.dtype == torch.bool):
raise TypeError("Was expecting both tensors to be bool type.")
else:
if a.dtype == torch.bool and b.dtype == torch.bool:
# we want to respect precision but as bool doesn't support subtraction,
# boolean tensor has to be converted to int
a = a.to(torch.int)
b = b.to(torch.int)
diff = a - b
if a.is_floating_point():
# check that NaNs are in the same locations
nan_mask = torch.isnan(a)
self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
diff[nan_mask] = 0
# inf check if allow_inf=True
if allow_inf:
inf_mask = torch.isinf(a)
inf_sign = inf_mask.sign()
self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
diff[inf_mask] = 0
# TODO: implement abs on CharTensor (int8)
if diff.is_signed() and diff.dtype != torch.int8:
diff = diff.abs()
max_err = diff.max()
tolerance = prec + prec * abs(a.max())
self.assertLessEqual(max_err, tolerance, message)
super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
if x.is_sparse:
x = self.safeCoalesce(x)
y = self.safeCoalesce(y)
assertTensorsEqual(x._indices(), y._indices())
assertTensorsEqual(x._values(), y._values())
elif x.is_quantized and y.is_quantized:
self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
message=message, allow_inf=allow_inf)
if x.qscheme() == torch.per_tensor_affine:
self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
message=message, allow_inf=allow_inf)
self.assertEqual(x.q_zero_point(), y.q_zero_point(),
prec=prec, message=message,
allow_inf=allow_inf)
elif x.qscheme() == torch.per_channel_affine:
self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
message=message, allow_inf=allow_inf)
self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
prec=prec, message=message,
allow_inf=allow_inf)
self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
prec=prec, message=message)
self.assertEqual(x.dtype, y.dtype)
self.assertEqual(x.int_repr().to(torch.int32),
y.int_repr().to(torch.int32), prec=prec,
message=message, allow_inf=allow_inf)
else:
assertTensorsEqual(x, y)
elif isinstance(x, string_classes) and isinstance(y, string_classes):
super(TestCase, self).assertEqual(x, y, message)
elif type(x) == set and type(y) == set:
super(TestCase, self).assertEqual(x, y, message)
elif isinstance(x, dict) and isinstance(y, dict):
if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
self.assertEqual(x.items(), y.items(), prec=prec,
message=message, allow_inf=allow_inf)
else:
self.assertEqual(set(x.keys()), set(y.keys()), prec=prec,
message=message, allow_inf=allow_inf)
key_list = list(x.keys())
self.assertEqual([x[k] for k in key_list],
[y[k] for k in key_list],
prec=prec, message=message,
allow_inf=allow_inf)
elif is_iterable(x) and is_iterable(y):
super(TestCase, self).assertEqual(len(x), len(y), message)
for x_, y_ in zip(x, y):
self.assertEqual(x_, y_, prec=prec, message=message,
allow_inf=allow_inf)
elif isinstance(x, bool) and isinstance(y, bool):
super(TestCase, self).assertEqual(x, y, message)
elif isinstance(x, Number) and isinstance(y, Number):
inf = float("inf")
if abs(x) == inf or abs(y) == inf:
if allow_inf:
super(TestCase, self).assertEqual(x, y, message)
else:
self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
return
super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
else:
super(TestCase, self).assertEqual(x, y, message)
def check_jit_scriptable(self, nn_module, args, unwrapper=None, skip=False):
"""
Check that a nn.Module's results in TorchScript match eager and that it
can be exported
"""
if not TEST_WITH_SLOW or skip:
# TorchScript is not enabled, skip these tests
msg = "The check_jit_scriptable test for {} was skipped. " \
"This test checks if the module's results in TorchScript " \
"match eager and that it can be exported. To run these " \
"tests make sure you set the environment variable " \
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
"manually skipped.".format(nn_module.__class__.__name__)
warnings.warn(msg, RuntimeWarning)
return None
sm = torch.jit.script(nn_module)
with freeze_rng_state():
eager_out = nn_module(*args)
with freeze_rng_state():
script_out = sm(*args)
if unwrapper:
script_out = unwrapper(script_out)
self.assertEqual(eager_out, script_out, prec=1e-4)
self.assertExportImportModule(sm, args)
return sm
def getExportImportCopy(self, m):
"""
Save and load a TorchScript model
"""
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer)
return imported
def assertExportImportModule(self, m, args):
"""
Check that the results of a model are the same after saving and loading
"""
m_import = self.getExportImportCopy(m)
with freeze_rng_state():
results = m(*args)
with freeze_rng_state():
results_from_imported = m_import(*args)
self.assertEqual(results, results_from_imported, prec=3e-5)
@contextlib.contextmanager
def freeze_rng_state():
rng_state = torch.get_rng_state()
...@@ -325,65 +97,18 @@ def freeze_rng_state():
torch.set_rng_state(rng_state)
class TransformsTester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
)
return batch_tensor
def compareTensorToPIL(self, tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
allowed_percentage_diff=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
self.assertTrue(
(tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
)
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def cycle_over(objs):
for idx, obj1 in enumerate(objs):
for obj2 in objs[:idx] + objs[idx + 1 :]:
yield obj1, obj2

def int_dtypes():
return (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)

def float_dtypes():
return (torch.float32, torch.float64)
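# A small usage sketch from the editor (not part of the commit) for the helpers above:
# cycle_over pairs every object with each of the remaining objects.
def _example_cycle_over():
    pairs = list(cycle_over([1, 2, 3]))
    # pairs == [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
    assert torch.uint8 in int_dtypes() and torch.float32 in float_dtypes()
    return pairs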
@contextlib.contextmanager
...@@ -394,66 +119,401 @@ def disable_console_output():
yield
def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
callable_or_arg_name = callable_or_arg_names[0]
if callable(callable_or_arg_name):
argspec = inspect.getfullargspec(callable_or_arg_name)
arg_names = argspec.args
if isinstance(callable_or_arg_name, type):
# remove self
arg_names.pop(0)
else:
arg_names = callable_or_arg_names
args, kwargs = call_args
kwargs_only = kwargs.copy()
kwargs_only.update(dict(zip(arg_names, args)))
return kwargs_only

def cpu_and_gpu():
# TODO: make this properly handle CircleCI
import pytest  # noqa
# ignore CPU tests in RE as they're already covered by another contbuild
devices = [] if IN_RE_WORKER else ['cpu']
if torch.cuda.is_available():
cuda_marks = ()
elif IN_FBCODE:
# Don't collect cuda tests on fbcode if the machine doesn't have a GPU
# This avoids skipping the tests. More robust would be to detect if
# we're in sandcastle instead of fbcode?
cuda_marks = pytest.mark.dont_collect()
else:
cuda_marks = pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)
devices.append(pytest.param('cuda', marks=cuda_marks))
return devices

def needs_cuda(test_func):
# TODO: make this properly handle CircleCI
import pytest  # noqa
if IN_FBCODE and not IN_RE_WORKER:
# We don't want to skip in fbcode, so we just don't collect
# TODO: slightly more robust way would be to detect if we're in a sandcastle instance
# so that the test will still be collected (and skipped) in the devvms.
return pytest.mark.dont_collect(test_func)
elif torch.cuda.is_available():
return test_func
else:
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)

def cpu_only(test_func):
# TODO: make this properly handle CircleCI
import pytest  # noqa
if IN_RE_WORKER:
# The assumption is that all RE workers have GPUs.
return pytest.mark.dont_collect(test_func)
else:
return test_func

def cpu_and_cuda():
import pytest  # noqa
return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))

def cpu_and_cuda_and_mps():
return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)

def needs_cuda(test_func):
import pytest  # noqa
return pytest.mark.needs_cuda(test_func)

def needs_mps(test_func):
import pytest  # noqa
return pytest.mark.needs_mps(test_func)

def _create_data(height=3, width=3, channels=3, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
data = tensor.permute(1, 2, 0).contiguous().cpu().numpy()
mode = "RGB"
if channels == 1:
mode = "L"
data = data[..., 0]
pil_img = Image.fromarray(data, mode=mode)
return tensor, pil_img

def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"):
# TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device)
return batch_tensor

def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmpdir, f"{i}.mp4")
names.append(name)
io.write_video(name, data, fps=f)
return names
def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
# FIXME: this is handled automatically by `assert_equal` below. Let's remove this in favor of it
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
assert_equal(tensor.cpu(), pil_tensor, msg=msg)
def _assert_approx_equal_tensor_to_pil(
tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
):
# FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
# TODO: we could just merge this into _assert_equal_tensor_to_pil
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
if allowed_percentage_diff is not None:
# Assert that less than a given %age of pixels are different
assert (tensor != pil_tensor).to(torch.float).mean() <= allowed_percentage_diff
# error value can be mean absolute error, max abs error
# Convert to float to avoid underflow when computing absolute difference
tensor = tensor.to(torch.float)
pil_tensor = pil_tensor.to(torch.float)
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
assert err < tol, f"{err} vs {tol}"
def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
transformed_batch = fn(batch_tensors, **fn_kwargs)
for i in range(len(batch_tensors)):
img_tensor = batch_tensors[i, ...]
transformed_img = fn(img_tensor, **fn_kwargs)
torch.testing.assert_close(transformed_img, transformed_batch[i, ...], rtol=0, atol=1e-6)
if scripted_fn_atol >= 0:
scripted_fn = torch.jit.script(fn)
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
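# Editor's illustration (not from the commit) of how a test might call _test_fn_on_batch to check
# that an op gives the same result per sample and on the whole batch. `_hflip_like` is a
# hypothetical stand-in for a torchvision functional op; scripted_fn_atol=-1 skips the JIT check.
def _example_test_fn_on_batch():
    def _hflip_like(img):
        return img.flip(-1)

    batch = _create_data_batch(height=4, width=6, channels=3, num_samples=2)
    _test_fn_on_batch(batch, _hflip_like, scripted_fn_atol=-1)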
def cache(fn):
"""Similar to :func:`functools.cache` (Python >= 3.8) or :func:`functools.lru_cache` with infinite cache size,
but this also caches exceptions.
"""
sentinel = object()
out_cache = {}
exc_tb_cache = {}
@functools.wraps(fn)
def wrapper(*args, **kwargs):
key = args + tuple(kwargs.values())
out = out_cache.get(key, sentinel)
if out is not sentinel:
return out
exc_tb = exc_tb_cache.get(key, sentinel)
if exc_tb is not sentinel:
raise exc_tb[0].with_traceback(exc_tb[1])
try:
out = fn(*args, **kwargs)
except Exception as exc:
# We need to cache the traceback here as well. Otherwise, each re-raise will add the internal pytest
# traceback frames anew, but they will only be removed once. Thus, the traceback will be ginormous hiding
# the actual information in the noise. See https://github.com/pytest-dev/pytest/issues/10363 for details.
exc_tb_cache[key] = exc, exc.__traceback__
raise exc
out_cache[key] = out
return out
return wrapper
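# A small usage sketch from the editor (not part of the diff): with the `cache` decorator above,
# both successful results and raised exceptions are replayed on later calls with the same arguments.
@cache
def _example_parse_positive(value):
    if value <= 0:
        raise ValueError(f"{value} is not positive")
    return value * 2

def _example_cache_behavior():
    assert _example_parse_positive(3) == 6  # computed once, then served from out_cache
    for _ in range(2):
        try:
            _example_parse_positive(-1)  # first call caches the exception, second call re-raises it
        except ValueError:
            pass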
def combinations_grid(**kwargs):
"""Creates a grid of input combinations.
Each element in the returned sequence is a dictionary containing one possible combination as values.
Example:
>>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
[
{'foo': 'bar', 'spam': 'eggs'},
{'foo': 'bar', 'spam': 'ham'},
{'foo': 'baz', 'spam': 'eggs'},
{'foo': 'baz', 'spam': 'ham'}
]
"""
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
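# Editor's sketch (not part of the diff): combinations_grid pairs naturally with
# pytest.mark.parametrize to run a test over the full cartesian product of options.
@pytest.mark.parametrize("config", combinations_grid(dtype=(torch.uint8, torch.float32), device=("cpu",)))
def _example_combinations_grid(config):
    tensor = torch.zeros(2, 2, dtype=config["dtype"], device=config["device"])
    assert tensor.dtype is config["dtype"]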
class ImagePair(TensorLikePair):
def __init__(
self,
actual,
expected,
*,
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = [to_image(input) for input in [actual, expected]]
super().__init__(actual, expected, **other_parameters)
self.mae = mae
def compare(self) -> None:
actual, expected = self.actual, self.expected
self._compare_attributes(actual, expected)
actual, expected = self._equalize_attributes(actual, expected)
if self.mae:
if actual.dtype is torch.uint8:
actual, expected = actual.to(torch.int), expected.to(torch.int)
mae = float(torch.abs(actual - expected).float().mean())
if mae > self.atol:
self._fail(
AssertionError,
f"The MAE of the images is {mae}, but only {self.atol} is allowed.",
)
else:
super()._compare_values(actual, expected)
def assert_close(
actual,
expected,
*,
allow_subclasses=True,
rtol=None,
atol=None,
equal_nan=False,
check_device=True,
check_dtype=True,
check_layout=True,
check_stride=False,
msg=None,
**kwargs,
):
"""Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
__tracebackhide__ = True
error_metas = not_close_error_metas(
actual,
expected,
pair_types=(
NonePair,
BooleanPair,
NumberPair,
ImagePair,
TensorLikePair,
),
allow_subclasses=allow_subclasses,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
check_stride=check_stride,
**kwargs,
)
if error_metas:
raise error_metas[0].to_error(msg)
assert_equal = functools.partial(assert_close, rtol=0, atol=0)
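# Editor's sketch (not part of the commit): assert_equal compares two PIL images through ImagePair,
# and assert_close with mae=True switches the comparison to a mean-absolute-error budget.
def _example_assert_close():
    image = torch.randint(0, 256, (3, 4, 5), dtype=torch.uint8)
    assert_equal(to_pil_image(image), to_pil_image(image))  # routed through ImagePair
    noisy = (image.to(torch.int16) + 1).clamp(0, 255).to(torch.uint8)  # every pixel off by at most 1
    assert_close(noisy, image, mae=True, atol=1.5, rtol=0)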
DEFAULT_SIZE = (17, 11)
NUM_CHANNELS_MAP = {
"GRAY": 1,
"GRAY_ALPHA": 2,
"RGB": 3,
"RGBA": 4,
}
def make_image(
size=DEFAULT_SIZE,
*,
color_space="RGB",
batch_dims=(),
dtype=None,
device="cpu",
memory_format=torch.contiguous_format,
):
num_channels = NUM_CHANNELS_MAP[color_space]
dtype = dtype or torch.uint8
max_value = get_max_value(dtype)
data = torch.testing.make_tensor(
(*batch_dims, num_channels, *size),
low=0,
high=max_value,
dtype=dtype,
device=device,
memory_format=memory_format,
)
if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value
return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs):
return make_image(*args, **kwargs).as_subclass(torch.Tensor)
def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))
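# Editor's sketch (not in the commit): the same random image can be materialized as a
# tv_tensors.Image, a plain tensor, or a PIL image via the three factories above.
def _example_make_image():
    image = make_image((8, 6), color_space="RGBA")
    assert isinstance(image, tv_tensors.Image) and image.shape == (4, 8, 6)
    plain = make_image_tensor((8, 6))
    assert type(plain) is torch.Tensor
    pil = make_image_pil((8, 6))
    assert pil.size == (6, 8)  # PIL reports (width, height)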
def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=tv_tensors.BoundingBoxFormat.XYXY,
num_boxes=1,
dtype=None,
device="cpu",
):
def sample_position(values, max_value):
# We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
# However, if we have batch_dims, we need tensors as limits.
return torch.stack([torch.randint(max_value - v, ()) for v in values.tolist()])
if isinstance(format, str):
format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y
x2 = x1 + w
y2 = y1 + h
parts = (x1, y1, x2, y2)
elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h)
else:
raise ValueError(f"Format {format} is not supported")
return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
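# Editor's sketch (not part of the diff): make_bounding_boxes yields boxes that fit inside
# canvas_size, already wrapped as tv_tensors.BoundingBoxes in the requested format.
def _example_make_bounding_boxes():
    boxes = make_bounding_boxes(canvas_size=(32, 64), format="XYXY", num_boxes=5)
    assert boxes.shape == (5, 4)
    assert boxes.format is tv_tensors.BoundingBoxFormat.XYXY
    assert boxes.canvas_size == (32, 64)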
def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
return tv_tensors.Mask(
torch.testing.make_tensor(
(num_masks, *size),
low=0,
high=2,
dtype=dtype or torch.bool,
device=device,
)
)
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, *size),
low=0,
high=num_categories,
dtype=dtype or torch.uint8,
device=device,
)
)
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)
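# Editor's sketch (not part of the diff): the mask and video factories follow the same pattern.
def _example_make_masks_and_video():
    det_masks = make_detection_masks((10, 12), num_masks=4)
    assert det_masks.shape == (4, 10, 12) and det_masks.dtype is torch.bool
    seg_mask = make_segmentation_mask((10, 12), num_categories=21)
    assert seg_mask.shape == (10, 12) and int(seg_mask.max()) < 21
    video = make_video((10, 12), num_frames=2)
    assert video.shape == (2, 3, 10, 12)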
def assert_run_python_script(source_code):
"""Utility to check assertions in an independent Python subprocess.
The script provided in the source code should return 0 and not print
anything on stderr or stdout. Modified from scikit-learn test utils.
Args:
source_code (str): The Python source code to execute.
"""
with get_tmp_dir() as root:
path = pathlib.Path(root) / "main.py"
with open(path, "w") as file:
file.write(source_code)
try:
out = check_output([sys.executable, str(path)], stderr=STDOUT)
except CalledProcessError as e:
raise RuntimeError(f"script errored with output:\n{e.output.decode()}")
if out != b"":
raise AssertionError(out.decode())
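# Editor's usage sketch (not part of the commit): the snippet runs in a fresh interpreter,
# so it has to import everything it needs and stay silent on success.
def _example_assert_run_python_script():
    assert_run_python_script(
        "import torch\n"
        "assert torch.tensor([1, 2, 3]).sum().item() == 6\n"
    )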
@contextlib.contextmanager
def assert_no_warnings():
# The name `catch_warnings` is a misnomer as the context manager does **not** catch any warnings, but rather scopes
# the warning filters. All changes that are made to the filters while in this context will be reset upon exit.
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
@contextlib.contextmanager
def ignore_jit_no_profile_information_warning():
# Calling a scripted object often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
yield
import random
import numpy as np
import pytest
import torch
from common_utils import (
CUDA_NOT_AVAILABLE_MSG,
IN_FBCODE,
IN_OSS_CI,
IN_RE_WORKER,
MPS_NOT_AVAILABLE_MSG,
OSS_CI_GPU_NO_CUDA_MSG,
)
def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
config.addinivalue_line("markers", "opcheck_only_one: only opcheck one parametrization")

def pytest_collection_modifyitems(items):
# This hook is called by pytest after it has collected the tests (google its name to check out its doc!)
# We can ignore some tests as we see fit here, or add marks, such as a skip mark.
#
# Typically, here, we try to optimize CI time. In particular, the GPU CI instances don't need to run the
# tests that don't need CUDA, because those tests are extensively tested in the CPU CI instances already.
# This is true for both OSS CI and the fbcode internal CI.
# In the fbcode CI, we have an additional constraint: we try to avoid skipping tests. So instead of relying on
# pytest.mark.skip, in fbcode we literally just remove those tests from the `items` list, and it's as if
# these tests never existed.
out_items = []
for item in items:
# The needs_cuda mark will exist if the test was explicitly decorated with
# the @needs_cuda decorator. It will also exist if it was parametrized with a
# parameter that has the mark: for example if a test is parametrized with
# @pytest.mark.parametrize('device', cpu_and_cuda())
# the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
# and the ones with device == 'cpu' won't have the mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None
needs_mps = item.get_closest_marker("needs_mps") is not None
if needs_cuda and not torch.cuda.is_available():
# In general, we skip cuda tests on machines without a GPU
# There are special cases though, see below
item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
if needs_mps and not torch.backends.mps.is_available():
item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
if IN_FBCODE:
# fbcode doesn't like skipping tests, so instead we just don't collect the test
# so that they don't even "exist", hence the continue statements.
if not needs_cuda and IN_RE_WORKER:
# The RE workers are the machines with GPU, we don't want them to run CPU-only tests.
continue
if needs_cuda and not torch.cuda.is_available():
# On the test machines without a GPU, we want to ignore the tests that need cuda.
# TODO: something more robust would be to do that only in a sandcastle instance,
# so that we can still see the test being skipped when testing locally from a devvm
continue
if needs_mps and not torch.backends.mps.is_available():
# Same as above, but for MPS
continue
elif IN_OSS_CI:
# Here we're not in fbcode, so we can safely collect and skip tests.
if not needs_cuda and torch.cuda.is_available():
# Similar to what happens in RE workers: we don't need the OSS CI GPU machines
# to run the CPU-only tests.
item.add_marker(pytest.mark.skip(reason=OSS_CI_GPU_NO_CUDA_MSG))
if item.get_closest_marker("dont_collect") is not None:
# currently, this is only used for some tests we're sure we don't want to run on fbcode
continue
out_items.append(item)
items[:] = out_items
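# Editor's illustration (not part of this conftest.py): cpu_and_cuda() from common_utils attaches
# the needs_cuda mark only to its "cuda" parameter, which is exactly what the collection logic
# above keys on.
def _example_device_parametrization():
    from common_utils import cpu_and_cuda  # helper defined in this commit's common_utils.py

    params = cpu_and_cuda()
    assert params[0] == "cpu"  # plain value, no mark attached
    assert any(mark.name == "needs_cuda" for mark in params[1].marks)  # the "cuda" param carries the mark
    return params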
def pytest_sessionfinish(session, exitstatus):
# This hook is called after all tests have run, and just before returning an exit status.
# We here change exit code 5 into 0.
#
# 5 is issued when no tests were actually run, e.g. if you use `pytest -k some_regex_that_is_never_matched`.
#
# Having no test being run for a given test rule is a common scenario in fbcode, and typically happens on
# the GPU test machines which don't run the CPU-only tests (see pytest_collection_modifyitems above). For
# example `test_transforms.py` doesn't contain any CUDA test at the time of
# writing, so on a GPU test machine, testpilot would invoke pytest on this file and no test would be run.
# This would result in pytest returning 5, causing testpilot to raise an error.
# To avoid this, we transform this 5 into a 0 to make testpilot happy.
if exitstatus == 5:
session.exitstatus = 0
@pytest.fixture(autouse=True)
def prevent_leaking_rng():
# Prevent each test from leaking the rng to all other tests when they call
# torch.manual_seed() or random.seed() or np.random.seed().
# Note: the numpy rngs should never leak anyway, as we never use
# np.random.seed() and instead rely on np.random.RandomState instances (see
# issue #4247). We still do it for extra precaution.
torch_rng_state = torch.get_rng_state()
builtin_rng_state = random.getstate()
numpy_rng_state = np.random.get_state()
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
yield
torch.set_rng_state(torch_rng_state)
random.setstate(builtin_rng_state)
np.random.set_state(numpy_rng_state)
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
...@@ -18,7 +18,7 @@ TEST(test_custom_operators, nms) {
double thresh = 0.7;
torch::jit::push(stack, boxes, scores, thresh);
op->getOperation()(&stack);
op->getOperation()(stack);
at::Tensor output_jit;
torch::jit::pop(stack, output_jit);
...@@ -47,7 +47,7 @@ TEST(test_custom_operators, roi_align_visible) {
bool aligned = true;
torch::jit::push(stack, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned);
op->getOperation()(&stack);
op->getOperation()(stack);
at::Tensor output_jit;
torch::jit::pop(stack, output_jit);
...
...@@ -5,20 +5,33 @@ import inspect
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size
__all__ = [
...@@ -33,6 +46,8 @@ __all__ = [
"create_image_folder",
"create_video_file",
"create_video_folder",
"make_tar",
"make_zip",
"create_random_string",
]
...@@ -55,6 +70,7 @@ class LazyImporter:
"requests",
"scipy.io",
"scipy.sparse",
"h5py",
)
def __init__(self):
...@@ -127,16 +143,16 @@ def test_all_configs(test):
.. note::
This will try to remove duplicate configurations. During this process it will not preserve a potential
ordering of the configurations or an inner ordering of a configuration.
"""
def maybe_remove_duplicates(configs):
try:
return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
except TypeError:
# A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
# removal would be a lot more elaborate, and we simply bail out.
return configs
@functools.wraps(test)
...@@ -159,23 +175,6 @@ def test_all_configs(test):
return wrapper
def combinations_grid(**kwargs):
"""Creates a grid of input combinations.
Each element in the returned sequence is a dictionary containing one possible combination as values.
Example:
>>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
[
{'foo': 'bar', 'spam': 'eggs'},
{'foo': 'bar', 'spam': 'ham'},
{'foo': 'baz', 'spam': 'eggs'},
{'foo': 'baz', 'spam': 'ham'}
]
"""
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
class DatasetTestCase(unittest.TestCase):
"""Abstract base class for all dataset testcases.
...@@ -287,7 +286,7 @@ class DatasetTestCase(unittest.TestCase):
.. note::
The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
Otherwise, you need to overwrite this method.
Args:
tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
...@@ -416,7 +415,11 @@ class DatasetTestCase(unittest.TestCase):
continue
defaults.append(
{
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
if not kwarg.startswith("_")
}
)
if not argspec.varkw:
...@@ -515,18 +518,18 @@ class DatasetTestCase(unittest.TestCase):
yield mocks
def test_not_found_or_corrupted(self):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass
def test_smoke(self):
with self.create_dataset() as (dataset, _):
assert isinstance(dataset, torchvision.datasets.VisionDataset)
@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
assert isinstance(str(dataset), str)
@test_all_configs
def test_feature_types(self, config):
...@@ -536,23 +539,21 @@ class DatasetTestCase(unittest.TestCase):
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
assert (
actual == expected
), f"The number of the returned features does not match the number of elements in FEATURE_TYPES: {actual} != {expected}"
else:
example = (example,)
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
assert isinstance(feature, expected_feature_type)
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
assert len(list(dataset)) == len(dataset) == info["num_examples"]
@test_all_configs
def test_transforms(self, config):
...@@ -569,6 +570,39 @@ class DatasetTestCase(unittest.TestCase):
mock.assert_called()
@test_all_configs
def test_transforms_v2_wrapper(self, config):
try:
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
torchvision.datasets.VOCDetection,
torchvision.datasets.Kitti,
torchvision.datasets.WIDERFace,
}:
with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
continue
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]
wrapped_sample = wrapped_dataset[0]
assert tree_any(
lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
pytest.skip(msg)
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
pytest.skip("Config is currently not supported by this wrapper")
raise error
class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
...@@ -592,7 +626,7 @@ class ImageDatasetTestCase(DatasetTestCase):
patch_checks=patch_checks,
**kwargs,
) as (dataset, info):
# PIL.Image.open() only loads the image metadata upfront and keeps the file open until the first access
# to the pixel data occurs. Trying to delete such a file results in a PermissionError on Windows. Thus, we
# force-load opened images.
# This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
...@@ -629,27 +663,76 @@ class VideoDatasetTestCase(DatasetTestCase):
FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
REQUIRED_PACKAGES = ("av",)
FRAMES_PER_CLIP = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)
def _set_default_frames_per_clip(self, dataset_args):
argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
@functools.wraps(dataset_args)
def wrapper(tmpdir, config):
args = dataset_args(tmpdir, config)
if frames_per_clip_last and len(args) == len(args_without_default) - 1:
args = (*args, self.FRAMES_PER_CLIP)
return args
return wrapper
def test_output_format(self):
for output_format in ["TCHW", "THWC"]:
with self.create_dataset(output_format=output_format) as (dataset, _):
for video, *_ in dataset:
if output_format == "TCHW":
num_frames, num_channels, *_ = video.shape
else: # output_format == "THWC":
num_frames, *_, num_channels = video.shape
assert num_frames == self.FRAMES_PER_CLIP
assert num_channels == 3
@test_all_configs
def test_transforms_v2_wrapper(self, config):
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
# or use the supported `"TCHW"`
if config.setdefault("output_format", "TCHW") == "THWC":
return
super().test_transforms_v2_wrapper.__wrapped__(self, config)
def _no_collate(batch):
return batch
def check_transforms_v2_wrapper_spawn(dataset, expected_size):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.
# To save CI/test time, we only check on Windows where "spawn" is the default
if platform.system() != "Windows":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
def resize_was_applied(item):
# Checking the size of the output ensures that the Resize transform was correctly applied
return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
expected_size
)
for wrapped_sample in dataloader:
assert tree_any(resize_was_applied, wrapped_sample)
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
...@@ -739,6 +822,33 @@ def create_image_folder(
]
def shape_test_for_stereo(
left: PIL.Image.Image,
right: PIL.Image.Image,
disparity: Optional[np.ndarray] = None,
valid_mask: Optional[np.ndarray] = None,
):
left_dims = get_dimensions(left)
right_dims = get_dimensions(right)
c, h, w = left_dims
# check that left and right are the same size
assert left_dims == right_dims
assert c == 3
# check that the disparity has the same spatial dimensions
# as the input
if disparity is not None:
assert disparity.ndim == 3
assert disparity.shape == (1, h, w)
if valid_mask is not None:
# check that valid mask is the same size as the disparity
_, dh, dw = disparity.shape
mh, mw = valid_mask.shape
assert dh == mh
assert dw == mw
@requires_lazy_imports("av") @requires_lazy_imports("av")
def create_video_file( def create_video_file(
root: Union[pathlib.Path, str], root: Union[pathlib.Path, str],
...@@ -747,7 +857,7 @@ def create_video_file( ...@@ -747,7 +857,7 @@ def create_video_file(
fps: float = 25, fps: float = 25,
**kwargs: Any, **kwargs: Any,
) -> pathlib.Path: ) -> pathlib.Path:
"""Create an video file from random data. """Create a video file from random data.
Args: Args:
root (Union[str, pathlib.Path]): Root directory the video file will be placed in. root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
...@@ -833,12 +943,86 @@ def create_video_folder( ...@@ -833,12 +943,86 @@ def create_video_folder(
] ]
def _split_files_or_dirs(root, *files_or_dirs):
files = set()
dirs = set()
for file_or_dir in files_or_dirs:
path = pathlib.Path(file_or_dir)
if not path.is_absolute():
path = root / path
if path.is_file():
files.add(path)
else:
dirs.add(path)
for sub_file_or_dir in path.glob("**/*"):
if sub_file_or_dir.is_file():
files.add(sub_file_or_dir)
else:
dirs.add(sub_file_or_dir)
if root in dirs:
dirs.remove(root)
return files, dirs
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
# We need to invoke `Path.with_suffix("")` in a loop, since each call only strips the last suffix if multiple
# suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else:
raise ValueError("No file or dir provided.")
files, dirs = _split_files_or_dirs(root, *files_or_dirs)
with opener(archive) as fh:
for file in sorted(files):
adder(fh, file, file.relative_to(root))
if remove:
for file in files:
os.remove(file)
for dir in dirs:
shutil.rmtree(dir, ignore_errors=True)
return archive
def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
# TODO: detect compression from name
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
remove=remove,
)
def make_zip(root, name, *files_or_dirs, remove=True):
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: zipfile.ZipFile(archive, "w"),
adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
remove=remove,
)
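# Editor's usage sketch (not from the commit): dataset mocks typically write a few files and then
# archive them. `root` is assumed to be a pathlib.Path pointing at a scratch directory. With
# remove=False the loose files are kept so a second archive can be built from them.
def _example_make_archives(root: pathlib.Path):
    create_image_folder(root, "images", lambda idx: f"{idx}.png", num_examples=3)
    zip_path = make_zip(root, "images.zip", "images", remove=False)
    tar_path = make_tar(root, "images.tar.gz", "images", compression="gz")  # removes the loose files
    return tar_path, zip_path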
def create_random_string(length: int, *digits: str) -> str:
"""Create a random string.
Args:
length (int): Number of characters in the generated string.
*digits (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
"""
if not digits:
digits = string.ascii_lowercase
...@@ -846,3 +1030,26 @@ def create_random_string(length: int, *digits: str) -> str:
digits = "".join(itertools.chain(*digits))
return "".join(random.choice(digits) for _ in range(length))
def make_fake_pfm_file(h, w, file_name):
values = list(range(3 * h * w))
# Note: we pack everything in little endian: -1.0, and "<"
content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
with open(file_name, "wb") as f:
f.write(content)
def make_fake_flo_file(h, w, file_name):
"""Creates a fake flow file in .flo format."""
# Everything needs to be in little Endian according to
# https://vision.middlebury.edu/flow/code/flow-code/README.txt
values = list(range(2 * h * w))
content = (
struct.pack("<4c", *(c.encode() for c in "PIEH"))
+ struct.pack("<i", w)
+ struct.pack("<i", h)
+ struct.pack("<" + "f" * len(values), *values)
)
with open(file_name, "wb") as f:
f.write(content)
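# Editor's cross-check sketch (not part of the commit), assuming only the .flo layout documented
# above: 4-byte "PIEH" magic, little-endian int32 width and height, then h*w*2 float32 values.
def _example_read_flo(file_name):
    with open(file_name, "rb") as f:
        assert f.read(4) == b"PIEH"
        w, h = struct.unpack("<2i", f.read(8))
        flow = np.frombuffer(f.read(), dtype="<f4").reshape(h, w, 2)
    return flow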