Commit c1f88ef1 authored by Ryuichiro Hataya's avatar Ryuichiro Hataya Committed by Soumith Chintala
Browse files

enable ~ expression in datasets (#186)

parent 432aa00d
......@@ -50,7 +50,7 @@ class CIFAR10(data.Dataset):
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
self.root = root
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
......
......@@ -44,7 +44,7 @@ class CocoCaptions(data.Dataset):
"""
def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO
self.root = root
self.root = os.path.expanduser(root)
self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys())
self.transform = transform
......
......@@ -23,6 +23,7 @@ def find_classes(dir):
def make_dataset(dir, class_to_idx):
images = []
dir = os.path.expanduser(dir)
for target in os.listdir(dir):
d = os.path.join(dir, target)
if not os.path.isdir(d):
......
......@@ -36,7 +36,7 @@ class MNIST(data.Dataset):
test_file = 'test.pt'
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
......
......@@ -49,7 +49,7 @@ class PhotoTour(data.Dataset):
matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False):
self.root = root
self.root = os.path.expanduser(root)
self.name = name
self.data_dir = os.path.join(root, name)
self.data_down = os.path.join(root, '{}.zip'.format(name))
......
......@@ -44,7 +44,7 @@ class STL10(CIFAR10):
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = root
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.split = split # train/test/unlabeled set
......
......@@ -38,7 +38,7 @@ class SVHN(data.Dataset):
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = root
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.split = split # training set or test set or extra set
......
......@@ -21,6 +21,7 @@ def check_integrity(fpath, md5):
def download_url(url, root, filename, md5):
from six.moves import urllib
root = os.path.expanduser(root)
fpath = os.path.join(root, filename)
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment