Commit c88d7fb5 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Add AffineTransformation (#793)

* Add AffineTransformation

Add AffineTransformation to supersede LinearTransformation

* Add test

* Add zero mean_vector in LinearTransformation and improved docs

* update

* minor fix

* minor fix2

* fixed flake8

* fix flake8

* fixed transpose syntax

* fixed shape of mean_vector in test

* fixed test

* print est cov and mean

* fixed flake8

* debug

* reduce num_samples

* debug

* fixed num_features

* fixed rtol for cov

* fix __repr__

* Update transforms.py

* Update test_transforms.py

* Update transforms.py

* fix flake8

* Update transforms.py

* Update transforms.py

* Update transforms.py

* Update transforms.py

* Changed dim of mean_vector to 1D, doc and removed .numpy () from format_string

* Restore test_linear_transformation()

* Update test_transforms.py
parent 1fa2f866
...@@ -974,6 +974,36 @@ class Tester(unittest.TestCase): ...@@ -974,6 +974,36 @@ class Tester(unittest.TestCase):
# Checking if LinearTransformation can be printed as string # Checking if LinearTransformation can be printed as string
whitening.__repr__() whitening.__repr__()
def test_affine_transformation(self):
    """Check that AffineTransformation whitens data.

    Builds a ZCA whitening matrix from the sample covariance of random
    data, applies the transform per-sample, and verifies the estimated
    covariance is ~identity and the estimated mean is ~zero. Also checks
    that the transform can be rendered via ``__repr__``.
    """
    num_samples = 1000
    x = torch.randn(num_samples, 3, 10, 10)
    flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
    # compute principal components of the sample covariance (ZCA whitening)
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    zca_epsilon = 1e-10  # avoid division by 0
    d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
    u = torch.Tensor(u)
    principal_components = torch.mm(torch.mm(u, d), u.t())
    mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
    # initialize whitening matrix
    whitening = transforms.AffineTransformation(principal_components, mean_vector)
    # estimate covariance and mean using weak law of large numbers
    num_features = flat_x.size(1)
    cov = 0.0
    mean = 0.0
    for i in x:
        xwhite = whitening(i)
        xwhite = xwhite.view(1, -1).numpy()
        cov += np.dot(xwhite, xwhite.T) / num_features
        mean += np.sum(xwhite) / num_features
    # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
    assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
    # Fix: np.allclose's rtol scales with the reference value; against a
    # reference of 0 it contributes nothing, leaving only the default
    # atol=1e-8 — far stricter than intended. Use atol for the zero check.
    assert np.allclose(mean / num_samples, 0, atol=1e-3), "mean not close to 0"
    # Checking if AffineTransformation can be printed as string
    whitening.__repr__()
def test_rotate(self): def test_rotate(self):
x = np.zeros((100, 100, 3), dtype=np.uint8) x = np.zeros((100, 100, 3), dtype=np.uint8)
x[40, 40] = [255, 255, 255] x[40, 40] = [255, 255, 255]
......
...@@ -27,7 +27,7 @@ else: ...@@ -27,7 +27,7 @@ else:
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"] "AffineTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
_pil_interpolation_to_str = { _pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST', Image.NEAREST: 'PIL.Image.NEAREST',
...@@ -710,28 +710,34 @@ class TenCrop(object): ...@@ -710,28 +710,34 @@ class TenCrop(object):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object): class AffineTransformation(object):
"""Transform a tensor image with a square transformation matrix computed """Transform a tensor image with a square transformation matrix and a mean_vector computed
offline. offline.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
Given transformation_matrix, will flatten the torch.*Tensor, compute the dot subtract mean_vector from it which is then followed by computing the dot
product with the transformation matrix and reshape the tensor to its product with the transformation matrix and then reshaping the tensor to its
original shape. original shape.
Applications: Applications:
- whitening: zero-center the data, compute the data covariance matrix - whitening transformation: Suppose X is a column vector zero-centered data.
[D x D] with np.dot(X.T, X), perform SVD on this matrix and Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
pass it as transformation_matrix. perform SVD on this matrix and pass it as transformation_matrix.
Args: Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [D], D = C x H x W
""" """
def __init__(self, transformation_matrix): def __init__(self, transformation_matrix, mean_vector):
if transformation_matrix.size(0) != transformation_matrix.size(1): if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " + raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
" as any one of the dimensions of the transformation_matrix [{} x {}]"
.format(transformation_matrix.size()))
self.transformation_matrix = transformation_matrix self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def __call__(self, tensor): def __call__(self, tensor):
""" """
...@@ -745,17 +751,29 @@ class LinearTransformation(object): ...@@ -745,17 +751,29 @@ class LinearTransformation(object):
raise ValueError("tensor and transformation matrix have incompatible shape." + raise ValueError("tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(*tensor.size()) + "[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0))) "{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1) flat_tensor = tensor.view(1, -1) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size()) tensor = transformed_tensor.view(tensor.size())
return tensor return tensor
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += (str(self.transformation_matrix.numpy().tolist()) + ')') format_string += (str(self.transformation_matrix.tolist()) + ')')
format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')')
return format_string return format_string
class LinearTransformation(AffineTransformation):
    """
    Note: This transform is deprecated in favor of AffineTransformation.
    """

    def __init__(self, transformation_matrix):
        warnings.warn("The use of the transforms.LinearTransformation transform is deprecated, "
                      "please use transforms.AffineTransformation instead.")
        # The legacy transform is AffineTransformation with a zero mean
        # vector; derive its length from one row of the (square) matrix.
        zero_mean = torch.zeros_like(transformation_matrix[0])
        super(LinearTransformation, self).__init__(transformation_matrix, zero_mean)
class ColorJitter(object): class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image. """Randomly change the brightness, contrast and saturation of an image.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment