Commit cee28367 authored by Iacopo Poli's avatar Iacopo Poli Committed by Francisco Massa
Browse files

Added Whiten transform and its section in README (#263)

* added Whiten transform and relative section in README

* added Whiten transform and its section in README

* fixed lint

* fixed lint again

* added linear transformation and whitening test unitary covariance

* fixed lint and test

* commit before rebasing

* before rebase

* added assertion to check matrix size

* fix readme

* added space around operator

* replace assert with raise ValueError, fix spaces and readme
parent 5b433d83
......@@ -377,6 +377,18 @@ Transforms on torch.\*Tensor
Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of
the torch.\*Tensor, i.e. channel = (channel - mean) / std
``LinearTransformation(transformation_matrix)``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Given ``transformation_matrix`` (D x D), where D = (C x H x W), will compute its
dot product with the flattened torch.\*Tensor and then reshape it to its
original dimensions.
Applications:
- whitening: zero-center the data, compute the data covariance matrix [D x D] with
np.dot(X.T, X), perform SVD on this matrix and pass the principal components as
transformation_matrix.
Conversion Transforms
~~~~~~~~~~~~~~~~~~~~~
......
......@@ -591,6 +591,25 @@ class Tester(unittest.TestCase):
y_pil_2 = color_jitter(x_pil_2)
assert y_pil_2.mode == x_pil_2.mode
def test_linear_transformation(self):
    """ZCA whitening built from the data's own covariance should produce
    whitened samples whose empirical covariance is the identity.

    Uses more samples (1000) than features (75) so the estimated covariance
    is non-singular; the original 250 x 300 setup produced a rank-deficient
    sigma whose near-zero eigenvalues blow up under 1 / sqrt(s + eps).
    """
    num_samples = 1000
    x = torch.randn(num_samples, 3, 5, 5)
    dim = x.size(1) * x.size(2) * x.size(3)  # D = C * H * W = 75
    flat_x = x.view(num_samples, dim)
    # compute ZCA whitening matrix: U diag(1/sqrt(s)) U^T from SVD of sigma
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    zca_epsilon = 1e-10  # avoid division by 0
    d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
    u = torch.Tensor(u)
    principal_components = torch.mm(torch.mm(u, d), u.t())
    whitening = transforms.LinearTransformation(principal_components)
    # whiten every sample through the transform and accumulate the
    # empirical covariance of the whitened data
    cov = np.zeros((dim, dim), dtype=np.float64)
    for sample in x:
        xwhite = whitening(sample).view(1, dim).numpy()
        cov += np.dot(xwhite.T, xwhite)
    cov /= num_samples
    # W sigma W^T == I exactly in exact arithmetic; allow float32 slack
    assert np.allclose(cov, np.identity(dim), atol=5e-3)
if __name__ == '__main__':
unittest.main()
......@@ -918,6 +918,47 @@ class TenCrop(object):
return ten_crop(img, self.size, self.vertical_flip)
class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.

    Given ``transformation_matrix``, the input torch.*Tensor is flattened to a
    row vector, right-multiplied by the matrix, and reshaped back to its
    original dimensions.

    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
      [D x D] with np.dot(X.T, X), perform SVD on this matrix and
      pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W

    Raises:
        ValueError: if ``transformation_matrix`` is not square.
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image, same shape as the input.

        Raises:
            ValueError: if C * H * W does not match the matrix dimension D.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            # note the trailing space after "shape." — without it the two
            # message fragments ran together ("shape.[3 x 10 x 10]")
            raise ValueError("tensor and transformation matrix have incompatible shape. " +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        return transformed_tensor.view(tensor.size())
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment