Unverified commit be8192e2 authored by vfdev, committed by GitHub

Improve transforms test codebase (#2620)

* Improve transforms test codebase
  - refactored the _create_data, compareTensorToPIL, and approxEqualTensorToPIL helpers into a shared TransformsTester base class

* Fixed flake8 errors
parent fc69c225
test/common_utils.py
@@ -14,6 +14,9 @@ from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
import numpy as np
from PIL import Image
@contextlib.contextmanager
def get_tmp_dir(src=None, **kwargs):
@@ -329,3 +332,28 @@ def freeze_rng_state():
if torch.cuda.is_available():
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(rng_state)

class TransformsTester(unittest.TestCase):

    def _create_data(self, height=3, width=3, channels=3, device="cpu"):
        # Random uint8 CHW tensor and the equivalent HWC PIL image
        tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
        return tensor, pil_img

    def compareTensorToPIL(self, tensor, pil_image, msg=None):
        # Exact comparison: convert the PIL image to a CHW tensor and require equality
        np_pil_image = np.array(pil_image)
        if np_pil_image.ndim == 2:
            # grayscale PIL images come back as HxW; add a channel dimension
            np_pil_image = np_pil_image[:, :, None]
        pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
        if msg is None:
            msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
        self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
        # Approximate comparison: reduce the absolute difference with the given
        # torch reduction ("mean", "max", ...) and check it against the tolerance
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
        err = getattr(torch, method)(torch.abs(tensor - pil_tensor)).item()
        self.assertTrue(
            err < tol,
            msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
        )
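
For context, here is a minimal sketch (not part of this change) of how a test module can now subclass the shared helper instead of redefining its own _create_data/compareTensorToPIL; the VerticalFlipTester name and the torch.flip / PIL transpose pairing are illustrative assumptions:

import unittest

import torch
from PIL import Image

from common_utils import TransformsTester


class VerticalFlipTester(TransformsTester):

    def test_vflip_matches_pil(self):
        # shared helper: uint8 CHW tensor and the equivalent PIL image
        tensor, pil_img = self._create_data(height=8, width=8, channels=3)

        flipped_tensor = torch.flip(tensor, dims=[1])            # flip along the height dimension
        flipped_pil = pil_img.transpose(Image.FLIP_TOP_BOTTOM)   # same flip on the PIL side

        # shared helper: converts the PIL result to a CHW tensor and asserts exact equality
        self.compareTensorToPIL(flipped_tensor, flipped_pil)


if __name__ == "__main__":
    unittest.main()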
test/test_functional_tensor.py
import unittest
import random
import colorsys
import math
from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np
import torch
import torchvision.transforms as transforms
@@ -14,27 +11,10 @@ import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
from common_utils import TransformsTester
class Tester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3, device="cpu"):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img
def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
mae = torch.abs(tensor - pil_tensor).mean().item()
self.assertTrue(
mae < tol,
msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
class Tester(TransformsTester):
def _test_vflip(self, device):
script_vflip = torch.jit.script(F_t.vflip)
......
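
A hedged sketch (not from this change) of where approxEqualTensorToPIL fits: when the tensor path keeps float precision but the PIL path rounds back to uint8, exact equality cannot hold, so the tolerance-based helper is used instead. The BrightnessTester name, the ImageEnhance-based reference, the 0.5 factor, and the tol=1.0 value are illustrative assumptions:

import unittest

from PIL import ImageEnhance

from common_utils import TransformsTester


class BrightnessTester(TransformsTester):

    def test_brightness_close_to_pil(self):
        tensor, pil_img = self._create_data(height=8, width=8, channels=3)

        factor = 0.5
        # float tensor path: exact arithmetic, no rounding back to uint8
        darkened_tensor = tensor.float() * factor
        # PIL path: ImageEnhance rounds every pixel back to uint8
        darkened_pil = ImageEnhance.Brightness(pil_img).enhance(factor)

        # per-pixel quantization error stays below 1, so a mean-absolute-error
        # tolerance of 1.0 accepts the rounding while still catching real bugs
        self.approxEqualTensorToPIL(darkened_tensor, darkened_pil, tol=1.0, method="mean")


if __name__ == "__main__":
    unittest.main()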
test/test_transforms_tensor.py
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np
import unittest
from common_utils import TransformsTester
class Tester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
return tensor, pil_img
def compareTensorToPIL(self, tensor, pil_image):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor))
class Tester(TransformsTester):
def _test_functional_geom_op(self, func, fn_kwargs):
if fn_kwargs is None:
......
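
And a similar sketch (again not part of this diff) for the transform-class tests: the same transform instance is applied to the tensor and to the PIL image and the results are compared with the shared helper. The HorizontalFlipTester name is hypothetical, and it assumes RandomHorizontalFlip already accepts tensor inputs on this branch; p=1.0 removes the randomness:

import unittest

from torchvision import transforms as T

from common_utils import TransformsTester


class HorizontalFlipTester(TransformsTester):

    def test_transform_matches_pil(self):
        tensor, pil_img = self._create_data(height=8, width=8, channels=3)

        # p=1.0 makes the random flip deterministic for this comparison
        transform = T.RandomHorizontalFlip(p=1.0)

        out_tensor = transform(tensor)
        out_pil = transform(pil_img)

        self.compareTensorToPIL(out_tensor, out_pil)


if __name__ == "__main__":
    unittest.main()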