Commit e56670db authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #51 from alykhantejani/fix_test_transforms

Use integer division in tests/test_transforms for array slice indices + various PEP-8 fixes
parents 6c7733f0 1610268e
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import unittest import unittest
import random import random
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def test_crop(self): def test_crop(self):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
...@@ -13,9 +12,9 @@ class Tester(unittest.TestCase): ...@@ -13,9 +12,9 @@ class Tester(unittest.TestCase):
owidth = random.randint(5, (width - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
oh1 = (height - oheight) / 2 oh1 = (height - oheight) // 2
ow1 = (width - owidth) / 2 ow1 = (width - owidth) // 2
imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth] imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
imgnarrow.fill_(0) imgnarrow.fill_(0)
result = transforms.Compose([ result = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
...@@ -23,7 +22,7 @@ class Tester(unittest.TestCase): ...@@ -23,7 +22,7 @@ class Tester(unittest.TestCase):
transforms.ToTensor(), transforms.ToTensor(),
])(img) ])(img)
assert result.sum() == 0, "height: " + str(height) + " width: " \ assert result.sum() == 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = transforms.Compose([ result = transforms.Compose([
...@@ -33,7 +32,7 @@ class Tester(unittest.TestCase): ...@@ -33,7 +32,7 @@ class Tester(unittest.TestCase):
])(img) ])(img)
sum1 = result.sum() sum1 = result.sum()
assert sum1 > 1, "height: " + str(height) + " width: " \ assert sum1 > 1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1 oheight += 1
owidth += 1 owidth += 1
result = transforms.Compose([ result = transforms.Compose([
...@@ -43,9 +42,9 @@ class Tester(unittest.TestCase): ...@@ -43,9 +42,9 @@ class Tester(unittest.TestCase):
])(img) ])(img)
sum2 = result.sum() sum2 = result.sum()
assert sum2 > 0, "height: " + str(height) + " width: " \ assert sum2 > 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
assert sum2 > sum1, "height: " + str(height) + " width: " \ assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
def test_scale(self): def test_scale(self):
height = random.randint(24, 32) * 2 height = random.randint(24, 32) * 2
...@@ -100,19 +99,19 @@ class Tester(unittest.TestCase): ...@@ -100,19 +99,19 @@ class Tester(unittest.TestCase):
transforms.Pad(padding), transforms.Pad(padding),
transforms.ToTensor(), transforms.ToTensor(),
])(img) ])(img)
assert result.size(1) == height + 2*padding assert result.size(1) == height + 2 * padding
assert result.size(2) == width + 2*padding assert result.size(2) == width + 2 * padding
def test_lambda(self): def test_lambda(self):
trans = transforms.Lambda(lambda x: x.add(10)) trans = transforms.Lambda(lambda x: x.add(10))
x = torch.randn(10) x = torch.randn(10)
y = trans(x) y = trans(x)
assert(y.equal(torch.add(x, 10))) assert (y.equal(torch.add(x, 10)))
trans = transforms.Lambda(lambda x: x.add_(10)) trans = transforms.Lambda(lambda x: x.add_(10))
x = torch.randn(10) x = torch.randn(10)
y = trans(x) y = trans(x)
assert(y.equal(x)) assert (y.equal(x))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment