Unverified Commit 1b415254 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Normalize, LinearTransformation are scriptable (#2645)

* [WIP] All transforms are now derived from torch.nn.Module
- Compose, RandomApply, Normalize can be jit scripted

* Fixed flake8

* Updated code and docs
- added getattr to Lambda and tests
- updated code and docs of Compose
- added failing test with append/extend on Composed.transforms

* Fixed flake8

* Updated code, tests and docs
parent 8dfcff74
...@@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as ...@@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as
Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or
random transformations applied on the batch of Tensor Images identically transform all the images of the batch. random transformations applied on the batch of Tensor Images identically transform all the images of the batch.
Scriptable transforms
^^^^^^^^^^^^^^^^^^^^^
In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`.
.. code:: python
transforms = torch.nn.Sequential(
transforms.CenterCrop(10),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
)
scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
`lambda` functions or ``PIL.Image``.
For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``.
.. autoclass:: Compose .. autoclass:: Compose
Transforms on PIL Image Transforms on PIL Image
......
...@@ -376,6 +376,63 @@ class Tester(TransformsTester): ...@@ -376,6 +376,63 @@ class Tester(TransformsTester):
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
) )
def test_normalize(self):
    """Verify T.Normalize produces identical results in eager mode and under
    torch.jit.script, for a single image tensor and for a batch of tensors."""
    img, _ = self._create_data(26, 34, device=self.device)
    # Normalize expects float input scaled to [0, 1]
    img = img.to(dtype=torch.float32) / 255.0
    batch = torch.rand(4, 3, 44, 56, device=self.device)

    # exercise the class interface
    transform = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    scripted = torch.jit.script(transform)
    self._test_transform_vs_scripted(transform, scripted, img)
    self._test_transform_vs_scripted_on_batch(transform, scripted, batch)
def test_linear_transformation(self):
    """Verify T.LinearTransformation matches its scripted version on a single
    image, then compare eager vs. scripted output manually on a batch."""
    c, h, w = 3, 24, 32
    img, _ = self._create_data(h, w, channels=c, device=self.device)

    dim = c * h * w
    matrix = torch.rand(dim, dim, device=self.device)
    mean_vector = torch.rand(dim, device=self.device)

    transform = T.LinearTransformation(matrix, mean_vector)
    scripted = torch.jit.script(transform)
    self._test_transform_vs_scripted(transform, scripted, img)

    # The generic _test_transform_vs_scripted_on_batch helper is skipped here:
    # scripted and non-scripted results are not exactly the same for it, so we
    # compare the two outputs directly under a fixed seed instead.
    batch = torch.rand(4, c, h, w, device=self.device)
    torch.manual_seed(12)
    eager_out = transform(batch)
    torch.manual_seed(12)
    scripted_out = scripted(batch)
    self.assertTrue(eager_out.equal(scripted_out))
def test_compose(self):
    """T.Compose itself is not scriptable, but an nn.Sequential built from the
    same transforms is; check both give equal output, and that scripting a
    Compose containing a lambda raises a RuntimeError."""
    img, _ = self._create_data(26, 34, device=self.device)
    img = img.to(dtype=torch.float32) / 255.0

    transforms = T.Compose([
        T.CenterCrop(10),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    sequential = torch.nn.Sequential(*transforms.transforms)
    scripted = torch.jit.script(sequential)

    torch.manual_seed(12)
    expected = transforms(img)
    torch.manual_seed(12)
    actual = scripted(img)
    self.assertTrue(expected.equal(actual), msg="{}".format(transforms))

    # a lambda inside Compose cannot be scripted
    bad = T.Compose([
        lambda x: x,
    ])
    with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
        torch.jit.script(bad)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
......
...@@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None): ...@@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=mode)
def normalize(tensor, mean, std, inplace=False): def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
"""Normalize a tensor image with mean and standard deviation. """Normalize a tensor image with mean and standard deviation.
.. note:: .. note::
...@@ -292,7 +292,7 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -292,7 +292,7 @@ def normalize(tensor, mean, std, inplace=False):
See :class:`~torchvision.transforms.Normalize` for more details. See :class:`~torchvision.transforms.Normalize` for more details.
Args: Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized. tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel. mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel. std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace. inplace(bool,optional): Bool to make this operation inplace.
...@@ -300,11 +300,11 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -300,11 +300,11 @@ def normalize(tensor, mean, std, inplace=False):
Returns: Returns:
Tensor: Normalized Tensor image. Tensor: Normalized Tensor image.
""" """
if not torch.is_tensor(tensor): if not isinstance(tensor, torch.Tensor):
raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor))) raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if tensor.ndimension() != 3: if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = ' raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size())) '{}.'.format(tensor.size()))
if not inplace: if not inplace:
...@@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False):
if (std == 0).any(): if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
if mean.ndim == 1: if mean.ndim == 1:
mean = mean[:, None, None] mean = mean.view(-1, 1, 1)
if std.ndim == 1: if std.ndim == 1:
std = std[:, None, None] std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std) tensor.sub_(mean).div_(std)
return tensor return tensor
......
...@@ -3,7 +3,7 @@ import numbers ...@@ -3,7 +3,7 @@ import numbers
import random import random
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Tuple, List, Optional from typing import Tuple, List, Optional, Any
import torch import torch
from PIL import Image from PIL import Image
...@@ -33,7 +33,7 @@ _pil_interpolation_to_str = { ...@@ -33,7 +33,7 @@ _pil_interpolation_to_str = {
} }
class Compose(object): class Compose:
"""Composes several transforms together. """Composes several transforms together.
Args: Args:
...@@ -44,6 +44,19 @@ class Compose(object): ...@@ -44,6 +44,19 @@ class Compose(object):
>>> transforms.CenterCrop(10), >>> transforms.CenterCrop(10),
>>> transforms.ToTensor(), >>> transforms.ToTensor(),
>>> ]) >>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not require
`lambda` functions or ``PIL.Image``.
""" """
def __init__(self, transforms): def __init__(self, transforms):
...@@ -63,7 +76,7 @@ class Compose(object): ...@@ -63,7 +76,7 @@ class Compose(object):
return format_string return format_string
class ToTensor(object): class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range Converts a PIL Image or numpy.ndarray (H x W x C) in the range
...@@ -94,7 +107,7 @@ class ToTensor(object): ...@@ -94,7 +107,7 @@ class ToTensor(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class PILToTensor(object): class PILToTensor:
"""Convert a ``PIL Image`` to a tensor of the same type. """Convert a ``PIL Image`` to a tensor of the same type.
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
...@@ -114,7 +127,7 @@ class PILToTensor(object): ...@@ -114,7 +127,7 @@ class PILToTensor(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class ConvertImageDtype(object): class ConvertImageDtype:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly """Convert a tensor image to the given ``dtype`` and scale the values accordingly
Args: Args:
...@@ -139,7 +152,7 @@ class ConvertImageDtype(object): ...@@ -139,7 +152,7 @@ class ConvertImageDtype(object):
return F.convert_image_dtype(image, self.dtype) return F.convert_image_dtype(image, self.dtype)
class ToPILImage(object): class ToPILImage:
"""Convert a tensor or an ndarray to PIL Image. """Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
...@@ -178,7 +191,7 @@ class ToPILImage(object): ...@@ -178,7 +191,7 @@ class ToPILImage(object):
return format_string return format_string
class Normalize(object): class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation. """Normalize a tensor image with mean and standard deviation.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input channels, this transform will normalize each channel of the input
...@@ -196,11 +209,12 @@ class Normalize(object): ...@@ -196,11 +209,12 @@ class Normalize(object):
""" """
def __init__(self, mean, std, inplace=False): def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean self.mean = mean
self.std = std self.std = std
self.inplace = inplace self.inplace = inplace
def __call__(self, tensor): def forward(self, tensor: Tensor) -> Tensor:
""" """
Args: Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized. tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
...@@ -358,7 +372,7 @@ class Pad(torch.nn.Module): ...@@ -358,7 +372,7 @@ class Pad(torch.nn.Module):
format(self.padding, self.fill, self.padding_mode) format(self.padding, self.fill, self.padding_mode)
class Lambda(object): class Lambda:
"""Apply a user-defined lambda as a transform. """Apply a user-defined lambda as a transform.
Args: Args:
...@@ -366,7 +380,8 @@ class Lambda(object): ...@@ -366,7 +380,8 @@ class Lambda(object):
""" """
def __init__(self, lambd): def __init__(self, lambd):
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" if not callable(lambd):
raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
self.lambd = lambd self.lambd = lambd
def __call__(self, img): def __call__(self, img):
...@@ -376,7 +391,7 @@ class Lambda(object): ...@@ -376,7 +391,7 @@ class Lambda(object):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '()'
class RandomTransforms(object): class RandomTransforms:
"""Base class for a list of transformations with randomness """Base class for a list of transformations with randomness
Args: Args:
...@@ -408,7 +423,7 @@ class RandomApply(RandomTransforms): ...@@ -408,7 +423,7 @@ class RandomApply(RandomTransforms):
""" """
def __init__(self, transforms, p=0.5): def __init__(self, transforms, p=0.5):
super(RandomApply, self).__init__(transforms) super().__init__(transforms)
self.p = p self.p = p
def __call__(self, img): def __call__(self, img):
...@@ -897,7 +912,7 @@ class TenCrop(torch.nn.Module): ...@@ -897,7 +912,7 @@ class TenCrop(torch.nn.Module):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object): class LinearTransformation(torch.nn.Module):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed """Transform a tensor image with a square transformation matrix and a mean_vector computed
offline. offline.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
...@@ -916,6 +931,7 @@ class LinearTransformation(object): ...@@ -916,6 +931,7 @@ class LinearTransformation(object):
""" """
def __init__(self, transformation_matrix, mean_vector): def __init__(self, transformation_matrix, mean_vector):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1): if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " + raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
...@@ -925,10 +941,14 @@ class LinearTransformation(object): ...@@ -925,10 +941,14 @@ class LinearTransformation(object):
" as any one of the dimensions of the transformation_matrix [{}]" " as any one of the dimensions of the transformation_matrix [{}]"
.format(tuple(transformation_matrix.size()))) .format(tuple(transformation_matrix.size())))
if transformation_matrix.device != mean_vector.device:
raise ValueError("Input tensors should be on the same device. Got {} and {}"
.format(transformation_matrix.device, mean_vector.device))
self.transformation_matrix = transformation_matrix self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector self.mean_vector = mean_vector
def __call__(self, tensor): def forward(self, tensor: Tensor) -> Tensor:
""" """
Args: Args:
tensor (Tensor): Tensor image of size (C, H, W) to be whitened. tensor (Tensor): Tensor image of size (C, H, W) to be whitened.
...@@ -936,13 +956,20 @@ class LinearTransformation(object): ...@@ -936,13 +956,20 @@ class LinearTransformation(object):
Returns: Returns:
Tensor: Transformed image. Tensor: Transformed image.
""" """
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): shape = tensor.shape
raise ValueError("tensor and transformation matrix have incompatible shape." + n = shape[-3] * shape[-2] * shape[-1]
"[{} x {} x {}] != ".format(*tensor.size()) + if n != self.transformation_matrix.shape[0]:
"{}".format(self.transformation_matrix.size(0))) raise ValueError("Input tensor and transformation matrix have incompatible shape." +
flat_tensor = tensor.view(1, -1) - self.mean_vector "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
"{}".format(self.transformation_matrix.shape[0]))
if tensor.device.type != self.mean_vector.device.type:
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device))
flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size()) tensor = transformed_tensor.view(shape)
return tensor return tensor
def __repr__(self): def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment