Commit a137e4f4 authored by Luke Yeager, committed by Soumith Chintala

[Lint] Fix most lint automatically with autopep8

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i
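
If the cleanup needs to be repeated on later changes, the same pipeline can be pointed at the 120-column limit configured in setup.cfg below, and flake8 (which reads the [flake8] section of setup.cfg) will list whatever autopep8 could not fix on its own. This is a minimal sketch, assuming autopep8 and flake8 are installed; the explicit --max-line-length flag is an assumption and was not part of the original command:

    # Re-apply autopep8 in place, matching the 120-column limit configured below (assumed flag).
    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i --max-line-length 120
    # flake8 picks up the [flake8] settings from setup.cfg and reports the remaining lint.
    flake8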
parent e9ec6ac8
[bdist_wheel]
universal=1
+[pep8]
+max-line-length = 120
+[flake8]
+max-line-length = 120
+ignore = F401,F403
+exclude = venv
@@ -14,7 +14,7 @@ print(a[3])
dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         shuffle=True, num_workers=2)
@@ -31,10 +31,9 @@ for i, data in enumerate(dataloader, 0):
# except StopIteration:
# miter = dataloader.__iter__()
# return miter.next()
# i=0
# while True:
# print(i)
# img, target = getBatch()
# i+=1
@@ -20,14 +20,13 @@ parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
if __name__ == "__main__":
    args = parser.parse_args()
    # Data loading code
    transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
-        transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
-                             std = [ 0.229, 0.224, 0.225 ]),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
    ])
    traindir = os.path.join(args.data, 'train')
@@ -47,4 +46,3 @@ if __name__ == "__main__":
        dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
        batch=(end_time - start_time) / float(batch_count),
        image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
@@ -11,36 +11,37 @@ if sys.version_info[0] == 2:
else:
    import pickle
class CIFAR10(data.Dataset):
    base_folder = 'cifar-10-batches-py'
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.'
                               + ' You can use download=True to download it')
        # now load the picked numpy arrays
        if self.train:
            self.train_data = []
@@ -83,10 +84,10 @@ class CIFAR10(data.Dataset):
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
-        img = Image.fromarray(np.transpose(img, (1,2,0)))
+        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            img = self.transform(img)
@@ -134,7 +135,7 @@ class CIFAR10(data.Dataset):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        # downloads file
        if os.path.isfile(fpath) and \
                hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
@@ -147,7 +148,7 @@ class CIFAR10(data.Dataset):
        cwd = os.getcwd()
        print('Extracting tar file')
        tar = tarfile.open(fpath, "r:gz")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
@@ -160,10 +161,9 @@ class CIFAR100(CIFAR10):
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
@@ -3,7 +3,9 @@ from PIL import Image
import os
import os.path
class CocoCaptions(data.Dataset):
    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
@@ -15,7 +17,7 @@ class CocoCaptions(data.Dataset):
    def __getitem__(self, index):
        coco = self.coco
        img_id = self.ids[index]
-        ann_ids = coco.getAnnIds(imgIds = img_id)
+        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        target = [ann['caption'] for ann in anns]
@@ -33,7 +35,9 @@ class CocoCaptions(data.Dataset):
    def __len__(self):
        return len(self.ids)
class CocoDetection(data.Dataset):
    def __init__(self, root, annFile, transform=None, target_transform=None):
        from pycocotools.coco import COCO
        self.root = root
@@ -45,7 +49,7 @@ class CocoDetection(data.Dataset):
    def __getitem__(self, index):
        coco = self.coco
        img_id = self.ids[index]
-        ann_ids = coco.getAnnIds(imgIds = img_id)
+        ann_ids = coco.getAnnIds(imgIds=img_id)
        target = coco.loadAnns(ann_ids)
        path = coco.loadImgs(img_id)[0]['file_name']
...
@@ -9,15 +9,18 @@ IMG_EXTENSIONS = [
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx):
    images = []
    for target in os.listdir(dir):
@@ -39,6 +42,7 @@ def default_loader(path):
class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
...
@@ -10,7 +10,9 @@ if sys.version_info[0] == 2:
else:
    import pickle
class LSUNClass(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        import lmdb
        self.db_path = db_path
@@ -20,11 +22,11 @@ class LSUNClass(data.Dataset):
            self.length = txn.stat()['entries']
        cache_file = '_cache_' + db_path.replace('/', '_')
        if os.path.isfile(cache_file):
-            self.keys = pickle.load( open( cache_file, "rb" ) )
+            self.keys = pickle.load(open(cache_file, "rb"))
        else:
            with self.env.begin(write=False) as txn:
-                self.keys = [ key for key, _ in txn.cursor() ]
+                self.keys = [key for key, _ in txn.cursor()]
-            pickle.dump( self.keys, open( cache_file, "wb" ) )
+            pickle.dump(self.keys, open(cache_file, "wb"))
        self.transform = transform
        self.target_transform = target_transform
@@ -53,11 +55,13 @@ class LSUNClass(data.Dataset):
    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
class LSUN(data.Dataset):
    """
    db_path = root directory for the database files
    classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
    """
    def __init__(self, db_path, classes='train',
                 transform=None, target_transform=None):
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
@@ -73,13 +77,13 @@ class LSUN(data.Dataset):
                c_short.pop(len(c_short) - 1)
                c_short = '_'.join(c_short)
                if c_short not in categories:
-                    raise(ValueError('Unknown LSUN class: ' + c_short + '.'\
+                    raise(ValueError('Unknown LSUN class: ' + c_short + '.'
                                     'Options are: ' + str(categories)))
                c_short = c.split('_')
                c_short = c_short.pop(len(c_short) - 1)
                if c_short not in dset_opts:
-                    raise(ValueError('Unknown postfix: ' + c_short + '.'\
+                    raise(ValueError('Unknown postfix: ' + c_short + '.'
                                     'Options are: ' + str(dset_opts)))
        else:
            raise(ValueError('Unknown option for classes'))
        self.classes = classes
@@ -88,8 +92,8 @@ class LSUN(data.Dataset):
        self.dbs = []
        for c in self.classes:
            self.dbs.append(LSUNClass(
-                db_path = db_path + '/' + c + '_lmdb',
-                transform = transform))
+                db_path=db_path + '/' + c + '_lmdb',
+                transform=transform))
        self.indices = []
        count = 0
@@ -128,9 +132,9 @@ if __name__ == '__main__':
    #lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
    #a = lsun[0]
    lsun = LSUN(db_path='/home/soumith/local/lsun/train',
                classes=['bedroom_train', 'church_outdoor_train'])
    print(lsun.classes)
    print(lsun.dbs)
-    a, t = lsun[len(lsun)-1]
+    a, t = lsun[len(lsun) - 1]
    print(a)
    print(t)
@@ -9,6 +9,7 @@ import json
import codecs
import numpy as np
class MNIST(data.Dataset):
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
@@ -25,7 +26,7 @@ class MNIST(data.Dataset):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        if download:
            self.download()
@@ -35,7 +36,8 @@ class MNIST(data.Dataset):
                               + ' You can use download=True to download it')
        if self.train:
-            self.train_data, self.train_labels = torch.load(os.path.join(root, self.processed_folder, self.training_file))
+            self.train_data, self.train_labels = torch.load(
+                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
@@ -65,7 +67,7 @@ class MNIST(data.Dataset):
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
    def download(self):
        from six.moves import urllib
@@ -92,7 +94,7 @@ class MNIST(data.Dataset):
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)
@@ -114,14 +116,17 @@ class MNIST(data.Dataset):
        print('Done!')
def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)
def parse_byte(b):
    if isinstance(b, str):
        return ord(b)
    return b
def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
@@ -131,6 +136,7 @@ def read_label_file(path):
    assert len(labels) == length
    return torch.LongTensor(labels)
def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
...
@@ -11,6 +11,7 @@ model_urls = {
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
...
@@ -94,6 +94,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
...
@@ -14,8 +14,9 @@ model_urls = {
class Fire(nn.Module):
    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
@@ -36,6 +37,7 @@ class Fire(nn.Module):
class SqueezeNet(nn.Module):
    def __init__(self, version=1.0, num_classes=1000):
        super(SqueezeNet, self).__init__()
        if version not in [1.0, 1.1]:
...
@@ -18,6 +18,7 @@ model_urls = {
class VGG(nn.Module):
    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
...
@@ -7,6 +7,7 @@ import numpy as np
import numbers
import types
class Compose(object):
    """Composes several transforms together.
@@ -19,6 +20,7 @@ class Compose(object):
        >>> transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms
@@ -32,6 +34,7 @@ class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array
@@ -56,12 +59,13 @@ class ToPILImage(object):
    or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255]
    """
    def __call__(self, pic):
        npimg = pic
        mode = None
        if not isinstance(npimg, np.ndarray):
            npimg = pic.mul(255).byte().numpy()
-            npimg = np.transpose(npimg, (1,2,0))
+            npimg = np.transpose(npimg, (1, 2, 0))
        if npimg.shape[2] == 1:
            npimg = npimg[:, :, 0]
@@ -75,6 +79,7 @@ class Normalize(object):
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
@@ -94,6 +99,7 @@ class Scale(object):
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
@@ -117,6 +123,7 @@ class CenterCrop(object):
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
@@ -133,6 +140,7 @@ class CenterCrop(object):
class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value"""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
@@ -142,8 +150,10 @@ class Pad(object):
    def __call__(self, img):
        return ImageOps.expand(img, border=self.padding, fill=self.fill)
class Lambda(object):
    """Applies a lambda as a transform."""
    def __init__(self, lambd):
        assert type(lambd) is types.LambdaType
        self.lambd = lambd
@@ -157,6 +167,7 @@ class RandomCrop(object):
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
@@ -181,6 +192,7 @@ class RandomCrop(object):
class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """
    def __call__(self, img):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
@@ -194,6 +206,7 @@ class RandomSizedCrop(object):
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
...
import torch
import math
def make_grid(tensor, nrow=8, padding=2):
    """
    Given a 4D mini-batch Tensor of shape (B x C x H x W),
@@ -15,13 +16,13 @@ def make_grid(tensor, nrow=8, padding=2):
        tensor = tensorlist[0].new(size)
        for i in range(numImages):
            tensor[i].copy_(tensorlist[i])
    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.view(1, tensor.size(0), tensor.size(1))
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:
            tensor = torch.cat((tensor, tensor, tensor), 0)
        return tensor
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)
    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
@@ -34,8 +35,8 @@ def make_grid(tensor, nrow=8, padding=2):
        for x in range(xmaps):
            if k >= nmaps:
                break
-            grid.narrow(1, y*height+1+padding//2,height-padding)\
-                .narrow(2, x*width+1+padding//2, width-padding)\
+            grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
+                .narrow(2, x * width + 1 + padding // 2, width - padding)\
                .copy_(tensor[k])
            k = k + 1
    return grid
@@ -49,6 +50,6 @@ def save_image(tensor, filename, nrow=8, padding=2):
    from PIL import Image
    tensor = tensor.cpu()
    grid = make_grid(tensor, nrow=nrow, padding=padding)
-    ndarr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0,2).transpose(0,1).numpy()
+    ndarr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0, 2).transpose(0, 1).numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)