Commit a137e4f4 authored by Luke Yeager's avatar Luke Yeager Committed by Soumith Chintala
Browse files

[Lint] Fix most lint automatically with autopep8

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i
parent e9ec6ac8
[bdist_wheel]
universal=1
[pep8]
max-line-length = 120
[flake8]
max-line-length = 120
ignore = F401,F403
exclude = venv
......@@ -14,7 +14,7 @@ print(a[3])
dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
shuffle=True, num_workers=2)
......@@ -31,10 +31,9 @@ for i, data in enumerate(dataloader, 0):
# except StopIteration:
# miter = dataloader.__iter__()
# return miter.next()
# i=0
# while True:
# print(i)
# img, target = getBatch()
# i+=1
......@@ -20,14 +20,13 @@ parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
if __name__ == "__main__":
args = parser.parse_args()
# Data loading code
transform = transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
std = [ 0.229, 0.224, 0.225 ]),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
traindir = os.path.join(args.data, 'train')
......@@ -47,4 +46,3 @@ if __name__ == "__main__":
dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0),
batch=(end_time - start_time) / float(batch_count),
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
......@@ -11,36 +11,37 @@ if sys.version_info[0] == 2:
else:
import pickle
class CIFAR10(data.Dataset):
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.'
raise RuntimeError('Dataset not found or corrupted.'
+ ' You can use download=True to download it')
# now load the picked numpy arrays
if self.train:
self.train_data = []
......@@ -83,10 +84,10 @@ class CIFAR10(data.Dataset):
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1,2,0)))
img = Image.fromarray(np.transpose(img, (1, 2, 0)))
if self.transform is not None:
img = self.transform(img)
......@@ -134,7 +135,7 @@ class CIFAR10(data.Dataset):
if self._check_integrity():
print('Files already downloaded and verified')
return
# downloads file
if os.path.isfile(fpath) and \
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
......@@ -147,7 +148,7 @@ class CIFAR10(data.Dataset):
cwd = os.getcwd()
print('Extracting tar file')
tar = tarfile.open(fpath, "r:gz")
os.chdir(root)
os.chdir(root)
tar.extractall()
tar.close()
os.chdir(cwd)
......@@ -160,10 +161,9 @@ class CIFAR100(CIFAR10):
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
......@@ -3,7 +3,9 @@ from PIL import Image
import os
import os.path
class CocoCaptions(data.Dataset):
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
self.root = root
......@@ -15,7 +17,7 @@ class CocoCaptions(data.Dataset):
def __getitem__(self, index):
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds = img_id)
ann_ids = coco.getAnnIds(imgIds=img_id)
anns = coco.loadAnns(ann_ids)
target = [ann['caption'] for ann in anns]
......@@ -33,7 +35,9 @@ class CocoCaptions(data.Dataset):
def __len__(self):
return len(self.ids)
class CocoDetection(data.Dataset):
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
self.root = root
......@@ -45,7 +49,7 @@ class CocoDetection(data.Dataset):
def __getitem__(self, index):
coco = self.coco
img_id = self.ids[index]
ann_ids = coco.getAnnIds(imgIds = img_id)
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
path = coco.loadImgs(img_id)[0]['file_name']
......
......@@ -9,15 +9,18 @@ IMG_EXTENSIONS = [
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
    """Return True if *filename* ends with a recognized image extension."""
    # str.endswith accepts a tuple of suffixes, so one call covers the
    # whole extension list (same case-sensitive match as the original).
    return filename.endswith(tuple(IMG_EXTENSIONS))
def find_classes(dir):
    """Enumerate the entries of *dir* in sorted order.

    Returns a pair ``(classes, class_to_idx)`` where ``classes`` is the
    sorted list of entry names and ``class_to_idx`` maps each name to its
    position in that list.
    """
    classes = sorted(os.listdir(dir))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(dir, class_to_idx):
images = []
for target in os.listdir(dir):
......@@ -39,6 +42,7 @@ def default_loader(path):
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
......
......@@ -10,7 +10,9 @@ if sys.version_info[0] == 2:
else:
import pickle
class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
......@@ -20,11 +22,11 @@ class LSUNClass(data.Dataset):
self.length = txn.stat()['entries']
cache_file = '_cache_' + db_path.replace('/', '_')
if os.path.isfile(cache_file):
self.keys = pickle.load( open( cache_file, "rb" ) )
self.keys = pickle.load(open(cache_file, "rb"))
else:
with self.env.begin(write=False) as txn:
self.keys = [ key for key, _ in txn.cursor() ]
pickle.dump( self.keys, open( cache_file, "wb" ) )
self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb"))
self.transform = transform
self.target_transform = target_transform
......@@ -53,11 +55,13 @@ class LSUNClass(data.Dataset):
def __repr__(self):
    """Return a short description: the class name followed by the db path."""
    # Same text as the original concatenation: "<ClassName> (<db_path>)".
    return '{0} ({1})'.format(self.__class__.__name__, self.db_path)
class LSUN(data.Dataset):
"""
db_path = root directory for the database files
classes = 'train' | 'val' | 'test' | ['bedroom_train', 'church_train', ...]
"""
def __init__(self, db_path, classes='train',
transform=None, target_transform=None):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
......@@ -73,13 +77,13 @@ class LSUN(data.Dataset):
c_short.pop(len(c_short) - 1)
c_short = '_'.join(c_short)
if c_short not in categories:
raise(ValueError('Unknown LSUN class: ' + c_short + '.'\
'Options are: ' + str(categories)))
raise(ValueError('Unknown LSUN class: ' + c_short + '.'
'Options are: ' + str(categories)))
c_short = c.split('_')
c_short = c_short.pop(len(c_short) - 1)
if c_short not in dset_opts:
raise(ValueError('Unknown postfix: ' + c_short + '.'\
'Options are: ' + str(dset_opts)))
raise(ValueError('Unknown postfix: ' + c_short + '.'
'Options are: ' + str(dset_opts)))
else:
raise(ValueError('Unknown option for classes'))
self.classes = classes
......@@ -88,8 +92,8 @@ class LSUN(data.Dataset):
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
db_path = db_path + '/' + c + '_lmdb',
transform = transform))
db_path=db_path + '/' + c + '_lmdb',
transform=transform))
self.indices = []
count = 0
......@@ -128,9 +132,9 @@ if __name__ == '__main__':
#lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
#a = lsun[0]
lsun = LSUN(db_path='/home/soumith/local/lsun/train',
classes=['bedroom_train', 'church_outdoor_train'])
classes=['bedroom_train', 'church_outdoor_train'])
print(lsun.classes)
print(lsun.dbs)
a, t = lsun[len(lsun)-1]
a, t = lsun[len(lsun) - 1]
print(a)
print(t)
......@@ -9,6 +9,7 @@ import json
import codecs
import numpy as np
class MNIST(data.Dataset):
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
......@@ -25,7 +26,7 @@ class MNIST(data.Dataset):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
self.train = train # training set or test set
if download:
self.download()
......@@ -35,7 +36,8 @@ class MNIST(data.Dataset):
+ ' You can use download=True to download it')
if self.train:
self.train_data, self.train_labels = torch.load(os.path.join(root, self.processed_folder, self.training_file))
self.train_data, self.train_labels = torch.load(
os.path.join(root, self.processed_folder, self.training_file))
else:
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
......@@ -65,7 +67,7 @@ class MNIST(data.Dataset):
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
def download(self):
from six.moves import urllib
......@@ -92,7 +94,7 @@ class MNIST(data.Dataset):
with open(file_path, 'wb') as f:
f.write(data.read())
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f:
gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read())
os.unlink(file_path)
......@@ -114,14 +116,17 @@ class MNIST(data.Dataset):
print('Done!')
def get_int(b):
    """Interpret the byte string *b* as a big-endian unsigned integer.

    Uses the hex codec rather than int.from_bytes so the helper keeps
    working on Python 2 as well as Python 3.
    """
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
def parse_byte(b):
    """Return the integer value of a single byte, bridging Python 2 and 3.

    Iterating a byte string yields length-1 ``str`` objects on Python 2
    (converted with ord) but plain ``int`` on Python 3 (returned as-is).
    """
    return ord(b) if isinstance(b, str) else b
def read_label_file(path):
with open(path, 'rb') as f:
data = f.read()
......@@ -131,6 +136,7 @@ def read_label_file(path):
assert len(labels) == length
return torch.LongTensor(labels)
def read_image_file(path):
with open(path, 'rb') as f:
data = f.read()
......
......@@ -11,6 +11,7 @@ model_urls = {
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
......
......@@ -94,6 +94,7 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
......
......@@ -14,8 +14,9 @@ model_urls = {
class Fire(nn.Module):
def __init__(self, inplanes, squeeze_planes,
expand1x1_planes, expand3x3_planes):
expand1x1_planes, expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
......@@ -36,6 +37,7 @@ class Fire(nn.Module):
class SqueezeNet(nn.Module):
def __init__(self, version=1.0, num_classes=1000):
super(SqueezeNet, self).__init__()
if version not in [1.0, 1.1]:
......
......@@ -18,6 +18,7 @@ model_urls = {
class VGG(nn.Module):
def __init__(self, features):
super(VGG, self).__init__()
self.features = features
......
......@@ -7,6 +7,7 @@ import numpy as np
import numbers
import types
class Compose(object):
"""Composes several transforms together.
......@@ -19,6 +20,7 @@ class Compose(object):
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
......@@ -32,6 +34,7 @@ class ToTensor(object):
"""Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
......@@ -56,12 +59,13 @@ class ToPILImage(object):
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
"""
def __call__(self, pic):
npimg = pic
mode = None
if not isinstance(npimg, np.ndarray):
npimg = pic.mul(255).byte().numpy()
npimg = np.transpose(npimg, (1,2,0))
npimg = np.transpose(npimg, (1, 2, 0))
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
......@@ -75,6 +79,7 @@ class Normalize(object):
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
......@@ -94,6 +99,7 @@ class Scale(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
......@@ -117,6 +123,7 @@ class CenterCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
......@@ -133,6 +140,7 @@ class CenterCrop(object):
class Pad(object):
"""Pads the given PIL.Image on all sides with the given "pad" value"""
def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number)
assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
......@@ -142,8 +150,10 @@ class Pad(object):
def __call__(self, img):
return ImageOps.expand(img, border=self.padding, fill=self.fill)
class Lambda(object):
"""Applies a lambda as a transform."""
def __init__(self, lambd):
assert type(lambd) is types.LambdaType
self.lambd = lambd
......@@ -157,6 +167,7 @@ class RandomCrop(object):
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def __init__(self, size, padding=0):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
......@@ -181,6 +192,7 @@ class RandomCrop(object):
class RandomHorizontalFlip(object):
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def __call__(self, img):
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
......@@ -194,6 +206,7 @@ class RandomSizedCrop(object):
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
......
import torch
import math
def make_grid(tensor, nrow=8, padding=2):
"""
Given a 4D mini-batch Tensor of shape (B x C x H x W),
......@@ -15,13 +16,13 @@ def make_grid(tensor, nrow=8, padding=2):
tensor = tensorlist[0].new(size)
for i in range(numImages):
tensor[i].copy_(tensorlist[i])
if tensor.dim() == 2: # single image H x W
if tensor.dim() == 2: # single image H x W
tensor = tensor.view(1, tensor.size(0), tensor.size(1))
if tensor.dim() == 3: # single image
if tensor.dim() == 3: # single image
if tensor.size(0) == 1:
tensor = torch.cat((tensor, tensor, tensor), 0)
return tensor
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
tensor = torch.cat((tensor, tensor, tensor), 1)
# make the mini-batch of images into a grid
nmaps = tensor.size(0)
......@@ -34,8 +35,8 @@ def make_grid(tensor, nrow=8, padding=2):
for x in range(xmaps):
if k >= nmaps:
break
grid.narrow(1, y*height+1+padding//2,height-padding)\
.narrow(2, x*width+1+padding//2, width-padding)\
grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
.narrow(2, x * width + 1 + padding // 2, width - padding)\
.copy_(tensor[k])
k = k + 1
return grid
......@@ -49,6 +50,6 @@ def save_image(tensor, filename, nrow=8, padding=2):
from PIL import Image
tensor = tensor.cpu()
grid = make_grid(tensor, nrow=nrow, padding=padding)
ndarr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0,2).transpose(0,1).numpy()
ndarr = grid.mul(0.5).add(0.5).mul(255).byte().transpose(0, 2).transpose(0, 1).numpy()
im = Image.fromarray(ndarr)
im.save(filename)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment