Commit c88d7fb5 authored by ekka, committed by Francisco Massa

Add AffineTransformation (#793)

* Add AffineTransformation

Add AffineTransformation to supersede LinearTransformation (see the usage sketch below)

* Add test

* Add zero mean_vector in LinearTransformation and improved docs

* update

* minor fix

* minor fix2

* fixed flake8

* fix flake8

* fixed transpose syntax

* fixed shape of mean_vector in test

* fixed test

* print est cov and mean

* fixed flake8

* debug

* reduce num_samples

* debug

* fixed num_features

* fixed rtol for cov

* fix __repr__

* Update transforms.py

* Update test_transforms.py

* Update transforms.py

* fix flake8

* Update transforms.py

* Update transforms.py

* Update transforms.py

* Update transforms.py

* Changed mean_vector to a 1D tensor, updated docs, and removed .numpy() from format_string

* Restore test_linear_transformation()

* Update test_transforms.py
parent 1fa2f866
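
A minimal usage sketch of the new API (the identity matrix, zero mean vector, and [3, 10, 10] shape are illustrative placeholders, not values from this commit):

```python
import torch
import torchvision.transforms as transforms

D = 3 * 10 * 10  # D = C x H x W for a [3, 10, 10] tensor image
transformation_matrix = torch.eye(D)  # placeholder; normally computed offline (e.g. a ZCA whitening matrix)
mean_vector = torch.zeros(D)          # placeholder; normally the per-feature mean of the training data

# Flattens the tensor, subtracts mean_vector, multiplies by the matrix, and restores the shape.
whitening = transforms.AffineTransformation(transformation_matrix, mean_vector)
out = whitening(torch.randn(3, 10, 10))  # output has the same shape as the input
```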
......@@ -974,6 +974,36 @@ class Tester(unittest.TestCase):
# Checking if LinearTransformation can be printed as string
whitening.__repr__()
def test_affine_transformation(self):
num_samples = 1000
x = torch.randn(num_samples, 3, 10, 10)
flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
# compute principal components
sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
u, s, _ = np.linalg.svd(sigma.numpy())
zca_epsilon = 1e-10 # avoid division by 0
d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
u = torch.Tensor(u)
principal_components = torch.mm(torch.mm(u, d), u.t())
mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
# initialize whitening matrix
whitening = transforms.AffineTransformation(principal_components, mean_vector)
# estimate covariance and mean using weak law of large number
num_features = flat_x.size(1)
cov = 0.0
mean = 0.0
for i in x:
xwhite = whitening(i)
xwhite = xwhite.view(1, -1).numpy()
cov += np.dot(xwhite, xwhite.T) / num_features
mean += np.sum(xwhite) / num_features
# if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
assert np.allclose(mean / num_samples, 0, rtol=1e-3), "mean not close to 0"
# Checking if AffineTransformation can be printed as string
whitening.__repr__()
def test_rotate(self):
x = np.zeros((100, 100, 3), dtype=np.uint8)
x[40, 40] = [255, 255, 255]
......
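
The test above builds the ZCA whitening matrix U diag(1 / sqrt(s + eps)) U^T inline; the same construction, factored into a helper for readability (the function name is illustrative, not part of the commit):

```python
import numpy as np
import torch


def zca_whitening_params(flat_x, eps=1e-10):
    """Return (principal_components, mean_vector) for flat_x of shape [N, D], as in the test."""
    # Second-moment matrix [D x D]; equals the covariance when the data is zero-mean.
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())                  # sigma = U diag(s) U^T
    d = torch.Tensor(np.diag(1. / np.sqrt(s + eps)))        # diag((s + eps)^-1/2); eps avoids division by 0
    principal_components = torch.mm(torch.mm(torch.Tensor(u), d), torch.Tensor(u).t())
    mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0)
    return principal_components, mean_vector
```

Applying AffineTransformation with these parameters maps the samples to approximately zero mean and identity covariance; the accumulation loop in the test checks the weaker per-feature statistics (average squared value close to 1, average value close to 0) within the stated tolerances.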
......@@ -27,7 +27,7 @@ else:
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
"AffineTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
......@@ -710,28 +710,34 @@ class TenCrop(object):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object):
"""Transform a tensor image with a square transformation matrix computed
class AffineTransformation(object):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.
Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
product with the transformation matrix and reshape the tensor to its
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor,
subtract mean_vector from it, compute the dot
product with the transformation matrix, and reshape the tensor to its
original shape.
Applications:
- whitening: zero-center the data, compute the data covariance matrix
[D x D] with np.dot(X.T, X), perform SVD on this matrix and
pass it as transformation_matrix.
- whitening transformation: Suppose X is a matrix of zero-centered data with one sample per row.
Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
perform SVD on this matrix and pass it as transformation_matrix.
Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [D], D = C x H x W
"""
def __init__(self, transformation_matrix):
def __init__(self, transformation_matrix, mean_vector):
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
" as any one of the dimensions of the transformation_matrix [{} x {}]"
.format(*transformation_matrix.size()))
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def __call__(self, tensor):
"""
......@@ -745,17 +751,29 @@ class LinearTransformation(object):
raise ValueError("tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(*tensor.size()) +
"{}".format(self.transformation_matrix.size(0)))
flat_tensor = tensor.view(1, -1)
flat_tensor = tensor.view(1, -1) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(tensor.size())
return tensor
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += (str(self.transformation_matrix.tolist()) + ', ')
format_string += ('mean_vector=' + str(self.mean_vector.tolist()) + ')')
return format_string
class LinearTransformation(AffineTransformation):
"""
Note: This transform is deprecated in favor of AffineTransformation.
"""
def __init__(self, transformation_matrix):
warnings.warn("The use of the transforms.LinearTransformation transform is deprecated, " +
"please use transforms.AffineTransformation instead.")
super(LinearTransformation, self).__init__(transformation_matrix, torch.zeros_like(transformation_matrix[0]))
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
......
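
And a sketch of the backward-compatibility path added above: LinearTransformation now delegates to AffineTransformation with a zero mean vector and warns that it is deprecated (the 12 x 12 matrix is an illustrative placeholder):

```python
import warnings

import torch
import torchvision.transforms as transforms

matrix = torch.eye(12)  # 12 = C x H x W for a [3, 2, 2] tensor image

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = transforms.LinearTransformation(matrix)  # still works, but warns
assert any("deprecated" in str(w.message) for w in caught)

# torch.zeros_like(matrix[0]) supplies the implicit mean_vector: a [D] zero tensor,
# so the result matches the previous LinearTransformation behaviour.
assert torch.equal(t.mean_vector, torch.zeros(12))
out = t(torch.randn(3, 2, 2))
```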