Commit 611c2348 authored by moskomule's avatar moskomule Committed by Soumith Chintala
Browse files

Expandusr (fixed mistakes) (#224)

* enable ~ expression in datasets

* fixed root in __init__ to self.root
parent c5e52baa
...@@ -68,7 +68,7 @@ class CIFAR10(data.Dataset): ...@@ -68,7 +68,7 @@ class CIFAR10(data.Dataset):
self.train_labels = [] self.train_labels = []
for fentry in self.train_list: for fentry in self.train_list:
f = fentry[0] f = fentry[0]
file = os.path.join(root, self.base_folder, f) file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb') fo = open(file, 'rb')
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
entry = pickle.load(fo) entry = pickle.load(fo)
...@@ -86,7 +86,7 @@ class CIFAR10(data.Dataset): ...@@ -86,7 +86,7 @@ class CIFAR10(data.Dataset):
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else: else:
f = self.test_list[0][0] f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f) file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb') fo = open(file, 'rb')
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
entry = pickle.load(fo) entry = pickle.load(fo)
......
...@@ -50,9 +50,9 @@ class MNIST(data.Dataset): ...@@ -50,9 +50,9 @@ class MNIST(data.Dataset):
if self.train: if self.train:
self.train_data, self.train_labels = torch.load( self.train_data, self.train_labels = torch.load(
os.path.join(root, self.processed_folder, self.training_file)) os.path.join(self.root, self.processed_folder, self.training_file))
else: else:
self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) self.test_data, self.test_labels = torch.load(os.path.join(self.root, self.processed_folder, self.test_file))
def __getitem__(self, index): def __getitem__(self, index):
""" """
......
...@@ -51,9 +51,9 @@ class PhotoTour(data.Dataset): ...@@ -51,9 +51,9 @@ class PhotoTour(data.Dataset):
def __init__(self, root, name, train=True, transform=None, download=False): def __init__(self, root, name, train=True, transform=None, download=False):
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
self.name = name self.name = name
self.data_dir = os.path.join(root, name) self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(root, '{}.zip'.format(name)) self.data_down = os.path.join(self.root, '{}.zip'.format(name))
self.data_file = os.path.join(root, '{}.pt'.format(name)) self.data_file = os.path.join(self.root, '{}.pt'.format(name))
self.train = train self.train = train
self.transform = transform self.transform = transform
......
...@@ -77,7 +77,7 @@ class STL10(CIFAR10): ...@@ -77,7 +77,7 @@ class STL10(CIFAR10):
self.test_list[0][0], self.test_list[1][0]) self.test_list[0][0], self.test_list[1][0])
class_file = os.path.join( class_file = os.path.join(
root, self.base_folder, self.class_names_file) self.root, self.base_folder, self.class_names_file)
if os.path.isfile(class_file): if os.path.isfile(class_file):
with open(class_file) as f: with open(class_file) as f:
self.classes = f.read().splitlines() self.classes = f.read().splitlines()
......
...@@ -63,7 +63,7 @@ class SVHN(data.Dataset): ...@@ -63,7 +63,7 @@ class SVHN(data.Dataset):
import scipy.io as sio import scipy.io as sio
# reading(loading) mat file as array # reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(root, self.filename)) loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
self.data = loaded_mat['X'] self.data = loaded_mat['X']
self.labels = loaded_mat['y'] self.labels = loaded_mat['y']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment