"src/vscode:/vscode.git/clone" did not exist on "9b5180cb5f00799ec47b778533db9dcbf83ceda4"
Unverified Commit 90a2402b authored by Philip Meier, committed by GitHub

Cleanup test suite related to `torch.testing.assert_close` (#4177)


Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent bf2fe567
"""This is a temporary module and should be removed as soon as torch.testing.assert_equal is supported."""
# TODO: remove this as soon torch.testing.assert_equal is supported
import functools
import torch.testing
__all__ = ["assert_equal"]
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
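The deleted helper above pins `torch.testing.assert_close` to exact comparison. A minimal standalone sketch of that behavior (my own illustration, not part of the diff):

```python
import functools

import torch
import torch.testing

# rtol=0 and atol=0 turn assert_close into an exact-equality check.
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)

assert_equal(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0]))  # passes

try:
    # Any nonzero difference fails, no matter how small.
    assert_equal(torch.tensor([1.0]), torch.tensor([1.0 + 1e-6]))
except AssertionError:
    print("exact comparison rejects a 1e-6 difference")
```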
@@ -9,6 +9,7 @@ import torch
 import __main__
 import random
 import inspect
+import functools
 from numbers import Number
 from torch._six import string_classes
@@ -17,8 +18,6 @@ from collections import OrderedDict
 import numpy as np
 from PIL import Image
-from _assert_utils import assert_equal
 IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
 PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
 PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
@@ -268,6 +267,9 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
     return batch_tensor
+assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
 def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
     np_pil_image = np.array(pil_image)
     if np_pil_image.ndim == 2:
@@ -275,7 +277,7 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
     pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
     if msg is None:
         msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
-    assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
+    assert_equal(tensor.cpu(), pil_tensor, msg=msg)
 def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
...
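The recurring edit through the rest of the diff drops `check_stride=False`; in current torch, `torch.testing.assert_close` ignores strides by default, so the flag is redundant. A small sketch (my own, not from the diff) of what that default means:

```python
import torch
import torch.testing

a = torch.arange(6).reshape(2, 3)  # contiguous, stride (3, 1)
b = a.t().contiguous().t()         # same values, different stride (1, 2)
assert a.stride() != b.stride()

# Passes: values match, and assert_close does not compare strides by default.
torch.testing.assert_close(a, b)
```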
@@ -13,8 +13,7 @@ from torchvision.datasets.samplers import (
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend
-from common_utils import get_tmp_dir
-from _assert_utils import assert_equal
+from common_utils import get_tmp_dir, assert_equal
 @contextlib.contextmanager
...
@@ -6,8 +6,7 @@ import pytest
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from common_utils import get_tmp_dir
-from _assert_utils import assert_equal
+from common_utils import get_tmp_dir, assert_equal
 @contextlib.contextmanager
@@ -41,7 +40,7 @@ class TestVideo:
             [0, 1, 2],
             [3, 4, 5],
         ])
-        assert_equal(r, expected, check_stride=False)
+        assert_equal(r, expected)
         r = unfold(a, 3, 2, 1)
         expected = torch.tensor([
@@ -49,14 +48,14 @@ class TestVideo:
             [2, 3, 4],
             [4, 5, 6]
         ])
-        assert_equal(r, expected, check_stride=False)
+        assert_equal(r, expected)
         r = unfold(a, 3, 2, 2)
         expected = torch.tensor([
             [0, 2, 4],
             [2, 4, 6],
         ])
-        assert_equal(r, expected, check_stride=False)
+        assert_equal(r, expected)
     @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
     def test_video_clips(self):
...
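For orientation, the `unfold` expectations in the hunks above line up with the built-in sliding-window view `torch.Tensor.unfold` when torchvision's extra dilation argument is 1; a sketch assuming `a = torch.arange(7)` as in the surrounding test:

```python
import torch

a = torch.arange(7)

# Windows of size 3, step 2, over dimension 0.
print(a.unfold(0, 3, 2))
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])
```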
@@ -21,8 +21,8 @@ from common_utils import (
     _assert_equal_tensor_to_pil,
     _assert_approx_equal_tensor_to_pil,
     _test_fn_on_batch,
+    assert_equal,
 )
-from _assert_utils import assert_equal
 from typing import Dict, List, Sequence, Tuple
@@ -187,11 +187,7 @@ class TestAffine:
             tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
         )
         if config is not None:
-            assert_equal(
-                torch.rot90(tensor, **config),
-                out_tensor,
-                check_stride=False,
-            )
+            assert_equal(torch.rot90(tensor, **config), out_tensor)
         if out_tensor.dtype != torch.uint8:
             out_tensor = out_tensor.to(torch.uint8)
@@ -856,7 +852,6 @@ def test_resized_crop(device, mode):
     assert_equal(
         expected_out_tensor,
         out_tensor,
-        check_stride=False,
         msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
     )
@@ -1001,10 +996,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
     ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
     out = fn(tensor, kernel_size=ksize, sigma=sigma)
-    torch.testing.assert_close(
-        out, true_out, rtol=0.0, atol=1.0, check_stride=False,
-        msg="{}, {}".format(ksize, sigma)
-    )
+    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma))
 @pytest.mark.parametrize('device', cpu_and_gpu())
...
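The gaussian-blur assertion above keeps `rtol=0.0, atol=1.0`, i.e. a tolerance of one intensity level against the reference output. A standalone sketch of what that tolerance admits:

```python
import torch
import torch.testing

out = torch.tensor([100.0, 101.0, 102.0])
ref = torch.tensor([100.0, 102.0, 101.0])

# atol=1.0 accepts off-by-one values; rtol=0.0 disables relative slack.
torch.testing.assert_close(out, ref, rtol=0.0, atol=1.0)
```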
@@ -9,8 +9,7 @@ import numpy as np
 import torch
 from PIL import Image, __version__ as PILLOW_VERSION
 import torchvision.transforms.functional as F
-from common_utils import get_tmp_dir, needs_cuda
-from _assert_utils import assert_equal
+from common_utils import get_tmp_dir, needs_cuda, assert_equal
 from torchvision.io.image import (
     decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -280,7 +279,7 @@ def test_read_1_bit_png(shape):
         img.save(image_path)
         img1 = read_image(image_path)
         img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
-        assert_equal(img1, img2, check_stride=False)
+        assert_equal(img1, img2)
 @pytest.mark.parametrize('shape', [
...
@@ -9,8 +9,7 @@ from torchvision import get_video_backend
 import warnings
 from urllib.error import URLError
-from common_utils import get_tmp_dir
-from _assert_utils import assert_equal
+from common_utils import get_tmp_dir, assert_equal
 try:
...
 import torch
-from common_utils import TestCase
-from _assert_utils import assert_equal
+from common_utils import TestCase, assert_equal
 from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
 from torchvision.models.detection.image_list import ImageList
 import pytest
...
@@ -7,7 +7,7 @@ from torchvision.models.detection.roi_heads import RoIHeads
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
 import pytest
-from _assert_utils import assert_equal
+from common_utils import assert_equal
 class TestModelsDetectionNegativeSamples:
...
@@ -4,7 +4,7 @@ from torchvision.models.detection import _utils
 from torchvision.models.detection.transform import GeneralizedRCNNTransform
 import pytest
 from torchvision.models.detection import backbone_utils
-from _assert_utils import assert_equal
+from common_utils import assert_equal
 class TestModelsDetectionUtils:
...
@@ -6,8 +6,7 @@ try:
 except ImportError:
     onnxruntime = None
-from common_utils import set_rng_seed
-from _assert_utils import assert_equal
+from common_utils import set_rng_seed, assert_equal
 import io
 import torch
 from torchvision import ops
...
-from common_utils import needs_cuda, cpu_and_gpu
-from _assert_utils import assert_equal
+from common_utils import needs_cuda, cpu_and_gpu, assert_equal
 import math
 from abc import ABC, abstractmethod
 import pytest
...
@@ -19,8 +19,7 @@ try:
 except ImportError:
     stats = None
-from common_utils import cycle_over, int_dtypes, float_dtypes
-from _assert_utils import assert_equal
+from common_utils import cycle_over, int_dtypes, float_dtypes, assert_equal
 GRACE_HOPPER = get_file_path_2(
@@ -159,7 +158,7 @@ class TestAccImage:
         output = trans(accimage.Image(GRACE_HOPPER))
         assert expected_output.size() == output.size()
-        torch.testing.assert_close(output, expected_output, check_stride=False)
+        torch.testing.assert_close(output, expected_output)
     def test_accimage_resize(self):
         trans = transforms.Compose([
@@ -205,23 +204,23 @@ class TestToTensor:
         input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
         img = transforms.ToPILImage()(input_data)
         output = trans(img)
-        torch.testing.assert_close(output, input_data, check_stride=False)
+        torch.testing.assert_close(output, input_data)
         ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
         output = trans(ndarray)
         expected_output = ndarray.transpose((2, 0, 1)) / 255.0
-        torch.testing.assert_close(output.numpy(), expected_output, check_stride=False, check_dtype=False)
+        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)
         ndarray = np.random.rand(height, width, channels).astype(np.float32)
         output = trans(ndarray)
         expected_output = ndarray.transpose((2, 0, 1))
-        torch.testing.assert_close(output.numpy(), expected_output, check_stride=False, check_dtype=False)
+        torch.testing.assert_close(output.numpy(), expected_output, check_dtype=False)
         # separate test for mode '1' PIL images
         input_data = torch.ByteTensor(1, height, width).bernoulli_()
         img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
         output = trans(img)
-        torch.testing.assert_close(input_data, output, check_dtype=False, check_stride=False)
+        torch.testing.assert_close(input_data, output, check_dtype=False)
     def test_to_tensor_errors(self):
         height, width = 4, 4
@@ -258,7 +257,7 @@ class TestToTensor:
         input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
         img = transforms.ToPILImage()(input_data)
         output = trans(img)
-        torch.testing.assert_close(input_data, output, check_stride=False)
+        torch.testing.assert_close(input_data, output)
         input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
         img = transforms.ToPILImage()(input_data)
@@ -270,13 +269,13 @@ class TestToTensor:
         img = transforms.ToPILImage()(input_data)  # CHW -> HWC and (* 255).byte()
         output = trans(img)  # HWC -> CHW
         expected_output = (input_data * 255).byte()
-        torch.testing.assert_close(output, expected_output, check_stride=False)
+        torch.testing.assert_close(output, expected_output)
         # separate test for mode '1' PIL images
         input_data = torch.ByteTensor(1, height, width).bernoulli_()
         img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
         output = trans(img).view(torch.uint8).bool().to(torch.uint8)
-        torch.testing.assert_close(input_data, output, check_stride=False)
+        torch.testing.assert_close(input_data, output)
     def test_pil_to_tensor_errors(self):
         height, width = 4, 4
@@ -420,10 +419,10 @@ class TestPad:
         h_padded = result[:, :padding, :]
         w_padded = result[:, :, :padding]
         torch.testing.assert_close(
-            h_padded, torch.full_like(h_padded, fill_value=fill_v), check_stride=False, rtol=0.0, atol=eps
+            h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps
         )
         torch.testing.assert_close(
-            w_padded, torch.full_like(w_padded, fill_value=fill_v), check_stride=False, rtol=0.0, atol=eps
+            w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps
         )
         pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)),
                       transforms.ToPILImage()(img))
@@ -457,7 +456,7 @@ class TestPad:
         # First 6 elements of leftmost edge in the middle of the image, values are in order:
        # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0
         edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6]
-        assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8), check_stride=False)
+        assert_equal(edge_middle_slice, np.asarray([200, 200, 200, 200, 1, 0], dtype=np.uint8))
         assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35)
         # Pad 3 to left/right, 2 to top/bottom
@@ -465,7 +464,7 @@ class TestPad:
         # First 6 elements of leftmost edge in the middle of the image, values are in order:
         # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0
         reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6]
-        assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8), check_stride=False)
+        assert_equal(reflect_middle_slice, np.asarray([0, 0, 1, 200, 1, 0], dtype=np.uint8))
         assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35)
         # Pad 3 to left, 2 to top, 2 to right, 1 to bottom
@@ -473,7 +472,7 @@ class TestPad:
         # First 6 elements of leftmost edge in the middle of the image, values are in order:
         # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0
         symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6]
-        assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8), check_stride=False)
+        assert_equal(symmetric_middle_slice, np.asarray([0, 1, 200, 200, 1, 0], dtype=np.uint8))
         assert transforms.ToTensor()(symmetric_padded_img).size() == (3, 32, 34)
         # Check negative padding explicitly for symmetric case, since it is not
@@ -482,8 +481,8 @@ class TestPad:
         symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric')
         symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3]
         symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:]
-        assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8), check_stride=False)
-        assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8), check_stride=False)
+        assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8))
+        assert_equal(symmetric_neg_middle_right, np.asarray([200, 200, 0, 0], dtype=np.uint8))
         assert transforms.ToTensor()(symmetric_padded_img_neg).size() == (3, 28, 31)
     def test_pad_raises_with_invalid_pad_sequence_len(self):
@@ -502,7 +501,7 @@ class TestPad:
         img = Image.new("F", (10, 10))
         padded_img = transform(img)
-        assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False)
+        assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size])
     @pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@@ -579,7 +578,7 @@ class TestToPil:
         img = transform(img_data)
         assert img.mode == expected_mode
-        torch.testing.assert_close(expected_output, to_tensor(img).numpy(), check_stride=False)
+        torch.testing.assert_close(expected_output, to_tensor(img).numpy())
     def test_1_channel_float_tensor_to_pil_image(self):
         img_data = torch.Tensor(1, 4, 4).uniform_()
@@ -617,7 +616,7 @@ class TestToPil:
         assert img.mode == expected_mode
         split = img.split()
         for i in range(2):
-            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
+            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
     def test_2_channel_ndarray_to_pil_image_error(self):
         img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy()
@@ -721,7 +720,7 @@ class TestToPil:
         assert img.mode == expected_mode
         split = img.split()
         for i in range(3):
-            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
+            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
     def test_3_channel_ndarray_to_pil_image_error(self):
         img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
@@ -778,7 +777,7 @@ class TestToPil:
         assert img.mode == expected_mode
         split = img.split()
         for i in range(4):
-            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]), check_stride=False)
+            torch.testing.assert_close(img_data[:, :, i], np.asarray(split[i]))
     def test_4_channel_ndarray_to_pil_image_error(self):
         img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy()
@@ -1152,7 +1151,7 @@ def test_to_grayscale():
     assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
     assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
     assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
-    assert_equal(gray_np, gray_np_2[:, :, 0], check_stride=False)
+    assert_equal(gray_np, gray_np_2[:, :, 0])
     # Case 3: 1 channel grayscale -> 1 channel grayscale
     trans3 = transforms.Grayscale(num_output_channels=1)
@@ -1170,7 +1169,7 @@ def test_to_grayscale():
     assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel'
     assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1])
     assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
-    assert_equal(gray_np, gray_np_4[:, :, 0], check_stride=False)
+    assert_equal(gray_np, gray_np_4[:, :, 0])
     # Checking if Grayscale can be printed as string
     trans4.__repr__()
@@ -1240,7 +1239,7 @@ def test_random_grayscale():
     assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel'
     assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1])
     assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2])
-    assert_equal(gray_np, gray_np_2[:, :, 0], check_stride=False)
+    assert_equal(gray_np, gray_np_2[:, :, 0])
     # Case 3b: RGB -> 3 channel grayscale (unchanged)
     trans2 = transforms.RandomGrayscale(p=0.0)
@@ -1600,8 +1599,9 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
     # Ensure output for PIL and Tensor are equal
     assert_equal(
-        output_tensor, output_pil, check_stride=False,
-        msg="image_size: {} crop_size: {}".format(input_image_size, crop_size)
+        output_tensor,
+        output_pil,
+        msg="image_size: {} crop_size: {}".format(input_image_size, crop_size),
     )
     # Check if content in center of both image and cropped output is same.
@@ -1625,7 +1625,7 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
         input_center_tl[1]:input_center_tl[1] + center_size[1]
     ]
-    assert_equal(output_center, img_center, check_stride=False)
+    assert_equal(output_center, img_center)
 def test_color_jitter():
...
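Several `ToTensor` round-trip checks in the file above pass `check_dtype=False`, since the tensor and the converted array hold the same values in different dtypes. A minimal sketch of that flag:

```python
import torch
import torch.testing

t_uint8 = torch.tensor([0, 128, 255], dtype=torch.uint8)
t_float = t_uint8.to(torch.float32)

# check_dtype=False promotes both inputs to a common dtype before comparing,
# so the uint8/float32 mismatch is tolerated while values are still checked.
torch.testing.assert_close(t_uint8, t_float, check_dtype=False)
```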
@@ -18,8 +18,8 @@ from common_utils import (
     _assert_equal_tensor_to_pil,
     _assert_approx_equal_tensor_to_pil,
     cpu_and_gpu,
+    assert_equal,
 )
-from _assert_utils import assert_equal
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
...
@@ -4,7 +4,7 @@ import pytest
 import random
 import numpy as np
 import warnings
-from _assert_utils import assert_equal
+from common_utils import assert_equal
 try:
     from scipy import stats
...
@@ -9,7 +9,7 @@ import torchvision.utils as utils
 from io import BytesIO
 import torchvision.transforms.functional as F
 from PIL import Image, __version__ as PILLOW_VERSION, ImageColor
-from _assert_utils import assert_equal
+from common_utils import assert_equal
 PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))
...
@@ -10,8 +10,7 @@ import torchvision.io as io
 from numpy.random import randint
 from torchvision import set_video_backend
 from torchvision.io import _HAS_VIDEO_OPT
-from common_utils import PY39_SKIP
-from _assert_utils import assert_equal
+from common_utils import PY39_SKIP, assert_equal
 try:
...