Unverified Commit 5be137e2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix STL10 repr (#969)

* Fix STL10 repr

* Do not inherit from Cifar10

* Make it safer to inherit from VisionDataset
parent c59f0474
......@@ -3,10 +3,12 @@ from PIL import Image
import os
import os.path
import numpy as np
from .cifar import CIFAR10
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract
class STL10(CIFAR10):
class STL10(VisionDataset):
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
Args:
......@@ -46,7 +48,7 @@ class STL10(CIFAR10):
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
self.root = os.path.expanduser(root)
super(STL10, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.split = split # train/test/unlabeled set
......@@ -129,5 +131,20 @@ class STL10(CIFAR10):
return images, labels
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self):
return "Split: {split}".format(**self.__dict__)
......@@ -37,7 +37,7 @@ class VisionDataset(data.Dataset):
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if self.transforms is not None:
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment