Commit 144ca427 authored by ekka, committed by Francisco Massa

Change AffineTransformation back to LinearTransformation (#843)

* Update test_transforms.py

* Update transforms.py
parent c2bfa661
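For reference, after this change the whitening transform is exposed only as transforms.LinearTransformation(transformation_matrix, mean_vector); the AffineTransformation name and the deprecated one-argument LinearTransformation wrapper are removed (see the diffs below). A minimal usage sketch, assuming the two-argument constructor shown in the diff; the tensor shape and the identity/zero inputs are purely illustrative:

import torch
import torchvision.transforms as transforms

img = torch.randn(3, 10, 10)       # tensor image (C, H, W)
n = img.numel()                    # 300 flattened features
identity = torch.eye(n)            # square transformation matrix
zero_mean = torch.zeros(n)         # mean_vector computed "offline"
whiten = transforms.LinearTransformation(identity, zero_mean)
out = whiten(img)                  # identity matrix and zero mean leave img unchanged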
test_transforms.py
@@ -953,28 +953,6 @@ class Tester(unittest.TestCase):
         color_jitter.__repr__()

     def test_linear_transformation(self):
-        x = torch.randn(250, 10, 10, 3)
-        flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
-        # compute principal components
-        sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
-        u, s, _ = np.linalg.svd(sigma.numpy())
-        zca_epsilon = 1e-10  # avoid division by 0
-        d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
-        u = torch.Tensor(u)
-        principal_components = torch.mm(torch.mm(u, d), u.t())
-        # initialize whitening matrix
-        whitening = transforms.LinearTransformation(principal_components)
-        # pass first vector
-        xwhite = whitening(x[0].view(10, 10, 3))
-        # estimate covariance
-        xwhite = xwhite.view(1, 300).numpy()
-        cov = np.dot(xwhite, xwhite.T) / x.size(0)
-        assert np.allclose(cov, np.identity(1), rtol=1e-3)
-        # Checking if LinearTransformation can be printed as string
-        whitening.__repr__()
-
-    def test_affine_transformation(self):
         num_samples = 1000
         x = torch.randn(num_samples, 3, 10, 10)
         flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
@@ -987,7 +965,7 @@ class Tester(unittest.TestCase):
         principal_components = torch.mm(torch.mm(u, d), u.t())
         mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
         # initialize whitening matrix
-        whitening = transforms.AffineTransformation(principal_components, mean_vector)
+        whitening = transforms.LinearTransformation(principal_components, mean_vector)
         # estimate covariance and mean using weak law of large number
         num_features = flat_x.size(1)
         cov = 0.0
@@ -1001,7 +979,7 @@ class Tester(unittest.TestCase):
         assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
         assert np.allclose(mean / num_samples, 0, rtol=1e-3), "mean not close to 0"
-        # Checking if AffineTransformation can be printed as string
+        # Checking if LinearTransformation can be printed as string
         whitening.__repr__()

     def test_rotate(self):
...
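The renamed test above builds a ZCA whitening matrix from the data and passes it, together with the mean vector, to the transform. Below is a standalone sketch of that construction, mirroring the computation in the test; the helper name, the explicit centering step, and the epsilon default are illustrative, not part of the commit:

import numpy as np
import torch
import torchvision.transforms as transforms

def zca_whitening_transform(data, eps=1e-10):
    # data is assumed to be a float tensor of shape (N, C, H, W)
    flat = data.view(data.size(0), -1)
    mean_vector = flat.sum(dim=0) / flat.size(0)
    centered = flat - mean_vector
    # covariance of the flattened samples (the test skips centering because
    # torch.randn data is already approximately zero-mean)
    sigma = torch.mm(centered.t(), centered) / centered.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    d = np.diag(1.0 / np.sqrt(s + eps))   # avoid division by zero
    # ZCA whitening matrix W = U * diag(1/sqrt(s)) * U^T
    principal_components = torch.Tensor(u @ d @ u.T)
    return transforms.LinearTransformation(principal_components, mean_vector)

The returned transform could then be computed once from the training data and placed after ToTensor() in a Compose pipeline, e.g. whiten = zca_whitening_transform(train_tensor) followed by transforms.Compose([transforms.ToTensor(), whiten]); again, just a sketch.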
transforms.py
@@ -27,7 +27,7 @@ else:
 __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
            "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
            "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
-           "AffineTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
+           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

 _pil_interpolation_to_str = {
     Image.NEAREST: 'PIL.Image.NEAREST',
@@ -710,7 +710,7 @@ class TenCrop(object):
         return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)


-class AffineTransformation(object):
+class LinearTransformation(object):
     """Transform a tensor image with a square transformation matrix and a mean_vector computed
     offline.
     Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
@@ -763,17 +763,6 @@ class AffineTransformation(object):
         return format_string


-class LinearTransformation(AffineTransformation):
-    """
-    Note: This transform is deprecated in favor of AffineTransformation.
-    """
-
-    def __init__(self, transformation_matrix):
-        warnings.warn("The use of the transforms.LinearTransformation transform is deprecated, " +
-                      "please use transforms.AffineTransformation instead.")
-        super(LinearTransformation, self).__init__(transformation_matrix, torch.zeros_like(transformation_matrix[0]))
-
-
 class ColorJitter(object):
     """Randomly change the brightness, contrast and saturation of an image.
...
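The class docstring above describes the transform as flattening the tensor image, subtracting the offline-computed mean_vector, and multiplying by the square transformation_matrix; the body of __call__ itself is not shown in this diff. A rough sketch of the operation the docstring describes, with an illustrative helper name:

import torch

def apply_linear_transformation(tensor, transformation_matrix, mean_vector):
    # flatten the (C, H, W) tensor image, subtract the precomputed mean,
    # apply the square matrix, and restore the original shape
    flat = tensor.view(1, -1) - mean_vector
    transformed = torch.mm(flat, transformation_matrix)
    return transformed.view(tensor.size())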