"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "869093e890ffface86430e378cc71ebcc080f2cf"
Commit 579eebea authored by Edouard Oyallon's avatar Edouard Oyallon Committed by Francisco Massa
Browse files

Enhancement of the STL10 loader (#914)

* modif of the STL10 loader

* missing space
parent 5be137e2
...@@ -16,6 +16,9 @@ class STL10(VisionDataset): ...@@ -16,6 +16,9 @@ class STL10(VisionDataset):
``stl10_binary`` exists. ``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}. split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly dataset is selected. Accordingly dataset is selected.
folds (int, optional): One of {0-9} or None.
For training, loads one of the 10 pre-defined folds of 1k samples for the
standard evaluation procedure. If no value is passed, loads the 5k samples.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop`` and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
...@@ -30,6 +33,7 @@ class STL10(VisionDataset): ...@@ -30,6 +33,7 @@ class STL10(VisionDataset):
filename = "stl10_binary.tar.gz" filename = "stl10_binary.tar.gz"
tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb' tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb'
class_names_file = 'class_names.txt' class_names_file = 'class_names.txt'
folds_list_file = 'fold_indices.txt'
train_list = [ train_list = [
['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'],
['train_y.bin', '5a34089d4802c674881badbb80307741'], ['train_y.bin', '5a34089d4802c674881badbb80307741'],
...@@ -42,7 +46,7 @@ class STL10(VisionDataset): ...@@ -42,7 +46,7 @@ class STL10(VisionDataset):
] ]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test') splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', def __init__(self, root, split='train', folds=None,
transform=None, target_transform=None, download=False): transform=None, target_transform=None, download=False):
if split not in self.splits: if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format( raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
...@@ -52,6 +56,7 @@ class STL10(VisionDataset): ...@@ -52,6 +56,7 @@ class STL10(VisionDataset):
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.split = split # train/test/unlabeled set self.split = split # train/test/unlabeled set
self.folds = folds # one of the 10 pre-defined folds or the full dataset
if download: if download:
self.download() self.download()
...@@ -65,9 +70,12 @@ class STL10(VisionDataset): ...@@ -65,9 +70,12 @@ class STL10(VisionDataset):
if self.split == 'train': if self.split == 'train':
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0]) self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds)
elif self.split == 'train+unlabeled': elif self.split == 'train+unlabeled':
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0]) self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data)) self.data = np.concatenate((self.data, unlabeled_data))
self.labels = np.concatenate( self.labels = np.concatenate(
...@@ -148,3 +156,16 @@ class STL10(VisionDataset): ...@@ -148,3 +156,16 @@ class STL10(VisionDataset):
def extra_repr(self): def extra_repr(self):
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
def __load_folds(self, folds):
    """Restrict ``self.data`` and ``self.labels`` to one pre-defined fold.

    Args:
        folds (int or None): Index (0-9) of the fold to keep; it selects the
            line of ``folds_list_file`` that is read. Any non-integer value
            (including ``None``) leaves the loaded data untouched.

    Raises:
        ValueError: If ``folds`` is an integer outside the range 0-9.
    """
    # Anything that is not an integer (notably None) means "keep every sample".
    if not isinstance(folds, int):
        return
    if not 0 <= folds < 10:
        raise ValueError(
            'Folds "{}" not found. Valid folds are: 0-9.'.format(folds))

    path_to_folds = os.path.join(
        self.root, self.base_folder, self.folds_list_file)
    with open(path_to_folds, 'r') as f:
        # Line number ``folds`` holds that fold's space-separated indices.
        str_idx = f.read().splitlines()[folds]

    # Parse into a wide integer type: uint8 cannot represent any index above
    # 255, which would silently select the wrong samples.
    list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ')
    self.data = self.data[list_idx, :, :, :]
    self.labels = self.labels[list_idx]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment