Unverified Commit e7c0abac authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Removed type from exception error (#2729)

Otherwise, torch jit scripted function raises exception on save
parent 53ccd538
import os
import torch import torch
from torchvision import transforms as T from torchvision import transforms as T
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
...@@ -8,7 +9,7 @@ import numpy as np ...@@ -8,7 +9,7 @@ import numpy as np
import unittest import unittest
from common_utils import TransformsTester from common_utils import TransformsTester, get_tmp_dir
class Tester(TransformsTester): class Tester(TransformsTester):
...@@ -73,6 +74,9 @@ class Tester(TransformsTester): ...@@ -73,6 +74,9 @@ class Tester(TransformsTester):
batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device) batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None): def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs) self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs) self._test_class_op(method, meth_kwargs)
...@@ -188,6 +192,9 @@ class Tester(TransformsTester): ...@@ -188,6 +192,9 @@ class Tester(TransformsTester):
scripted_fn = torch.jit.script(f) scripted_fn = torch.jit.script(f)
scripted_fn(tensor) scripted_fn(tensor)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None): def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None: if fn_kwargs is None:
fn_kwargs = {} fn_kwargs = {}
...@@ -231,6 +238,9 @@ class Tester(TransformsTester): ...@@ -231,6 +238,9 @@ class Tester(TransformsTester):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...])) msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
def test_five_crop(self): def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)} fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_op_list_output( self._test_op_list_output(
...@@ -294,6 +304,9 @@ class Tester(TransformsTester): ...@@ -294,6 +304,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
def test_resized_crop(self): def test_resized_crop(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device) tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
...@@ -309,6 +322,9 @@ class Tester(TransformsTester): ...@@ -309,6 +322,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))
def test_random_affine(self): def test_random_affine(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device) tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
...@@ -327,6 +343,9 @@ class Tester(TransformsTester): ...@@ -327,6 +343,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
def test_random_rotate(self): def test_random_rotate(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device) tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
...@@ -343,6 +362,9 @@ class Tester(TransformsTester): ...@@ -343,6 +362,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
def test_random_perspective(self): def test_random_perspective(self):
tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device) tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
...@@ -358,6 +380,9 @@ class Tester(TransformsTester): ...@@ -358,6 +380,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(transform, s_transform, tensor) self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
def test_to_grayscale(self): def test_to_grayscale(self):
meth_kwargs = {"num_output_channels": 1} meth_kwargs = {"num_output_channels": 1}
...@@ -388,6 +413,9 @@ class Tester(TransformsTester): ...@@ -388,6 +413,9 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(fn, scripted_fn, tensor) self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_linear_transformation(self): def test_linear_transformation(self):
c, h, w = 3, 24, 32 c, h, w = 3, 24, 32
...@@ -410,6 +438,9 @@ class Tester(TransformsTester): ...@@ -410,6 +438,9 @@ class Tester(TransformsTester):
s_transformed_batch = scripted_fn(batch_tensors) s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch)) self.assertTrue(transformed_batch.equal(s_transformed_batch))
with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
def test_compose(self): def test_compose(self):
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = self._create_data(26, 34, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
......
...@@ -15,7 +15,7 @@ def _get_image_size(img: Tensor) -> List[int]: ...@@ -15,7 +15,7 @@ def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image""" """Returns (w, h) of tensor image"""
if _is_tensor_a_torch_image(img): if _is_tensor_a_torch_image(img):
return [img.shape[-1], img.shape[-2]] return [img.shape[-1], img.shape[-2]]
raise TypeError("Unexpected type {}".format(type(img))) raise TypeError("Unexpected input type")
def _get_image_num_channels(img: Tensor) -> int: def _get_image_num_channels(img: Tensor) -> int:
...@@ -24,7 +24,7 @@ def _get_image_num_channels(img: Tensor) -> int: ...@@ -24,7 +24,7 @@ def _get_image_num_channels(img: Tensor) -> int:
elif img.ndim > 2: elif img.ndim > 2:
return img.shape[-3] return img.shape[-3]
raise TypeError("Unexpected type {}".format(type(img))) raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
def vflip(img: Tensor) -> Tensor: def vflip(img: Tensor) -> Tensor:
...@@ -223,7 +223,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: ...@@ -223,7 +223,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor image. Got {}'.format(type(img))) raise TypeError('Input img should be Tensor image')
orig_dtype = img.dtype orig_dtype = img.dtype
if img.dtype == torch.uint8: if img.dtype == torch.uint8:
...@@ -294,7 +294,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: ...@@ -294,7 +294,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
""" """
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
raise TypeError('img should be a Tensor. Got {}'.format(type(img))) raise TypeError('Input img should be a Tensor.')
if gamma < 0: if gamma < 0:
raise ValueError('Gamma should be a non-negative real number') raise ValueError('Gamma should be a non-negative real number')
...@@ -763,10 +763,10 @@ def _assert_grid_transform_inputs( ...@@ -763,10 +763,10 @@ def _assert_grid_transform_inputs(
coeffs: Optional[List[float]] = None, coeffs: Optional[List[float]] = None,
): ):
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError("img should be Tensor Image. Got {}".format(type(img))) raise TypeError("Input img should be Tensor Image")
if matrix is not None and not isinstance(matrix, list): if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix))) raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6: if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values") raise ValueError("Argument matrix should have 6 float values")
...@@ -989,7 +989,7 @@ def perspective( ...@@ -989,7 +989,7 @@ def perspective(
Tensor: transformed image. Tensor: transformed image.
""" """
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) raise TypeError('Input img should be Tensor Image')
_interpolation_modes = { _interpolation_modes = {
0: "nearest", 0: "nearest",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment