Unverified Commit 1b415254 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Normalize, LinearTransformation are scriptable (#2645)

* [WIP] All transforms are now derived from torch.nn.Module
- Compose, RandomApply, Normalize can be jit scripted

* Fixed flake8

* Updated code and docs
- added getattr to Lambda and tests
- updated code and docs of Compose
- added failing test with append/extend on Composed.transforms

* Fixed flake8

* Updated code, tests and docs
parent 8dfcff74
......@@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or
random transformations applied on the batch of Tensor Images identically transform all the images of the batch.
Scriptable transforms
^^^^^^^^^^^^^^^^^^^^^
In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.
.. code:: python
transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
``lambda`` functions or ``PIL.Image``.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
.. autoclass:: Compose
Transforms on PIL Image
......
......@@ -376,6 +376,63 @@ class Tester(TransformsTester):
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
def test_normalize(self):
    """Exercise ``T.Normalize`` eagerly and through ``torch.jit.script``.

    A single image and a random batch are both passed to the fixture
    helpers ``_test_transform_vs_scripted`` and
    ``_test_transform_vs_scripted_on_batch`` (defined outside this view;
    presumably they compare eager and scripted outputs -- TODO confirm).
    """
    device = self.device

    # Fixture data is created before the random batch so the order of
    # random draws stays the same as before.
    image, _ = self._create_data(26, 34, device=device)
    batch = torch.rand(4, 3, 44, 56, device=device)

    # Convert to float32 and divide by 255 before normalizing
    # (assumes _create_data yields 0..255 values -- TODO confirm).
    image = image.to(dtype=torch.float32) / 255.0

    # Only the class interface is tested here.
    eager_fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    jit_fn = torch.jit.script(eager_fn)

    self._test_transform_vs_scripted(eager_fn, jit_fn, image)
    self._test_transform_vs_scripted_on_batch(eager_fn, jit_fn, batch)
def test_linear_transformation(self):
    """Exercise ``T.LinearTransformation`` eagerly and when scripted.

    The single-image case goes through ``_test_transform_vs_scripted``.
    The batch case is checked by hand: both variants run on the same
    batch under the same seed and their outputs must be exactly equal.
    """
    device = self.device
    c, h, w = 3, 24, 32
    n_features = c * h * w

    image, _ = self._create_data(h, w, channels=c, device=device)

    # Square (n_features x n_features) matrix and matching mean vector.
    matrix = torch.rand(n_features, n_features, device=device)
    mean_vector = torch.rand(n_features, device=device)

    eager_fn = T.LinearTransformation(matrix, mean_vector)
    jit_fn = torch.jit.script(eager_fn)
    self._test_transform_vs_scripted(eager_fn, jit_fn, image)

    batch = torch.rand(4, c, h, w, device=device)

    # _test_transform_vs_scripted_on_batch is not used here: some of its
    # checks are skipped because scripted and non-scripted results are not
    # exactly the same. Only whole-batch equality is asserted.
    torch.manual_seed(12)
    eager_out = eager_fn(batch)
    torch.manual_seed(12)
    jit_out = jit_fn(batch)
    self.assertTrue(eager_out.equal(jit_out))
def test_compose(self):
    """Check ``T.Compose`` against a scripted ``torch.nn.Sequential``.

    The transforms held by a ``Compose`` are repacked into a
    ``torch.nn.Sequential``, scripted, and compared with the eager
    ``Compose`` output. Scripting a ``Compose`` that wraps a lambda is
    then expected to raise ``RuntimeError``.
    """
    image, _ = self._create_data(26, 34, device=self.device)
    image = image.to(dtype=torch.float32) / 255.0

    pipeline = T.Compose([
        T.CenterCrop(10),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Compose itself is not scripted; its transform list is unpacked
    # into a Sequential, which is.
    sequential = torch.nn.Sequential(*pipeline.transforms)
    jit_fn = torch.jit.script(sequential)

    # Same seed before each run so both see the same RNG state.
    torch.manual_seed(12)
    eager_out = pipeline(image)
    torch.manual_seed(12)
    jit_out = jit_fn(image)
    self.assertTrue(eager_out.equal(jit_out), msg="{}".format(pipeline))

    # Scripting a Compose built from a bare lambda must fail.
    unscriptable = T.Compose([
        lambda x: x,
    ])
    with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
        torch.jit.script(unscriptable)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
......
......@@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std, inplace=False):
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
"""Normalize a tensor image with mean and standard deviation.
.. note::
......@@ -292,7 +292,7 @@ def normalize(tensor, mean, std, inplace=False):
See :class:`~torchvision.transforms.Normalize` for more details.
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.
......@@ -300,11 +300,11 @@ def normalize(tensor, mean, std, inplace=False):
Returns:
Tensor: Normalized Tensor image.
"""
if not torch.is_tensor(tensor):
raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if tensor.ndimension() != 3:
raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = '
if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size()))
if not inplace:
......@@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False):
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
if mean.ndim == 1:
mean = mean[:, None, None]
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std[:, None, None]
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
......
......@@ -3,7 +3,7 @@ import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional
from typing import Tuple, List, Optional, Any
import torch
from PIL import Image
......@@ -33,7 +33,7 @@ _pil_interpolation_to_str = {
}
class Compose(object):
class Compose:
"""Composes several transforms together.
Args:
......@@ -44,6 +44,19 @@ class Compose(object):
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
``lambda`` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
......@@ -63,7 +76,7 @@ class Compose(object):
return format_string
class ToTensor(object):
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
......@@ -94,7 +107,7 @@ class ToTensor(object):
return self.__class__.__name__ + '()'
class PILToTensor(object):
class PILToTensor:
"""Convert a ``PIL Image`` to a tensor of the same type.
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
......@@ -114,7 +127,7 @@ class PILToTensor(object):
return self.__class__.__name__ + '()'
class ConvertImageDtype(object):
class ConvertImageDtype:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
Args:
......@@ -139,7 +152,7 @@ class ConvertImageDtype(object):
return F.convert_image_dtype(image, self.dtype)
class ToPILImage(object):
class ToPILImage:
"""Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
......@@ -178,7 +191,7 @@ class ToPILImage(object):
return format_string
class Normalize(object):
class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
......@@ -196,11 +209,12 @@ class Normalize(object):
"""
def __init__(self, mean, std, inplace=False):
    """Store the normalization parameters on the module.

    Args:
        mean: Per-channel means, stored as given.
        std: Per-channel standard deviations, stored as given.
        inplace: Flag stored as given; defaults to ``False``.
    """
    # The parent initializer runs before any attribute is assigned.
    super().__init__()
    self.mean, self.std = mean, std
    self.inplace = inplace
def __call__(self, tensor):
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
......@@ -358,7 +372,7 @@ class Pad(torch.nn.Module):
format(self.padding, self.fill, self.padding_mode)
class Lambda(object):
class Lambda:
"""Apply a user-defined lambda as a transform.
Args:
......@@ -366,7 +380,8 @@ class Lambda(object):
"""
def __init__(self, lambd):
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
if not callable(lambd):
raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
self.lambd = lambd
def __call__(self, img):
......@@ -376,7 +391,7 @@ class Lambda(object):
return self.__class__.__name__ + '()'
class RandomTransforms(object):
class RandomTransforms:
"""Base class for a list of transformations with randomness
Args:
......@@ -408,7 +423,7 @@ class RandomApply(RandomTransforms):
"""
def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__(transforms)
super().__init__(transforms)
self.p = p
def __call__(self, img):
......@@ -897,7 +912,7 @@ class TenCrop(torch.nn.Module):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object):
class LinearTransformation(torch.nn.Module):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
......@@ -916,6 +931,7 @@ class LinearTransformation(object):
"""
def __init__(self, transformation_matrix, mean_vector):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
......@@ -925,10 +941,14 @@ class LinearTransformation(object):
" as any one of the dimensions of the transformation_matrix [{}]"
.format(tuple(transformation_matrix.size())))
if transformation_matrix.device != mean_vector.device:
raise ValueError("Input tensors should be on the same device. Got {} and {}"
.format(transformation_matrix.device, mean_vector.device))
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def __call__(self, tensor):
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be whitened.
......@@ -936,13 +956,20 @@ class LinearTransformation(object):
Returns:
Tensor: Transformed image.
"""
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
raise ValueError("tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1) - self.mean_vector
shape = tensor.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError("Input tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
"{}".format(self.transformation_matrix.shape[0]))
if tensor.device.type != self.mean_vector.device.type:
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device))
flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size())
tensor = transformed_tensor.view(shape)
return tensor
def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment