"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "92ab6a8d5f0596156b1c96fc1cc90041fe2336d7"
Commit c1f88ef1 authored by Ryuichiro Hataya's avatar Ryuichiro Hataya Committed by Soumith Chintala
Browse files

enable ~ expression in datasets (#186)

parent 432aa00d
...@@ -50,7 +50,7 @@ class CIFAR10(data.Dataset): ...@@ -50,7 +50,7 @@ class CIFAR10(data.Dataset):
def __init__(self, root, train=True, def __init__(self, root, train=True,
transform=None, target_transform=None, transform=None, target_transform=None,
download=False): download=False):
self.root = root self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.train = train # training set or test set self.train = train # training set or test set
......
...@@ -44,7 +44,7 @@ class CocoCaptions(data.Dataset): ...@@ -44,7 +44,7 @@ class CocoCaptions(data.Dataset):
""" """
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = root self.root = os.path.expanduser(root)
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys()) self.ids = list(self.coco.imgs.keys())
self.transform = transform self.transform = transform
......
...@@ -23,6 +23,7 @@ def find_classes(dir): ...@@ -23,6 +23,7 @@ def find_classes(dir):
def make_dataset(dir, class_to_idx): def make_dataset(dir, class_to_idx):
images = [] images = []
dir = os.path.expanduser(dir)
for target in os.listdir(dir): for target in os.listdir(dir):
d = os.path.join(dir, target) d = os.path.join(dir, target)
if not os.path.isdir(d): if not os.path.isdir(d):
......
...@@ -36,7 +36,7 @@ class MNIST(data.Dataset): ...@@ -36,7 +36,7 @@ class MNIST(data.Dataset):
test_file = 'test.pt' test_file = 'test.pt'
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.train = train # training set or test set self.train = train # training set or test set
......
...@@ -49,7 +49,7 @@ class PhotoTour(data.Dataset): ...@@ -49,7 +49,7 @@ class PhotoTour(data.Dataset):
matches_files = 'm50_100000_100000_0.txt' matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False): def __init__(self, root, name, train=True, transform=None, download=False):
self.root = root self.root = os.path.expanduser(root)
self.name = name self.name = name
self.data_dir = os.path.join(root, name) self.data_dir = os.path.join(root, name)
self.data_down = os.path.join(root, '{}.zip'.format(name)) self.data_down = os.path.join(root, '{}.zip'.format(name))
......
...@@ -44,7 +44,7 @@ class STL10(CIFAR10): ...@@ -44,7 +44,7 @@ class STL10(CIFAR10):
def __init__(self, root, split='train', def __init__(self, root, split='train',
transform=None, target_transform=None, download=False): transform=None, target_transform=None, download=False):
self.root = root self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.split = split # train/test/unlabeled set self.split = split # train/test/unlabeled set
......
...@@ -38,7 +38,7 @@ class SVHN(data.Dataset): ...@@ -38,7 +38,7 @@ class SVHN(data.Dataset):
def __init__(self, root, split='train', def __init__(self, root, split='train',
transform=None, target_transform=None, download=False): transform=None, target_transform=None, download=False):
self.root = root self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.split = split # training set or test set or extra set self.split = split # training set or test set or extra set
......
...@@ -21,6 +21,7 @@ def check_integrity(fpath, md5): ...@@ -21,6 +21,7 @@ def check_integrity(fpath, md5):
def download_url(url, root, filename, md5): def download_url(url, root, filename, md5):
from six.moves import urllib from six.moves import urllib
root = os.path.expanduser(root)
fpath = os.path.join(root, filename) fpath = os.path.join(root, filename)
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment