Unverified Commit 5be137e2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix STL10 repr (#969)

* Fix STL10 repr

* Do not inherit from Cifar10

* Make it safer to inherit from VisionDataset
parent c59f0474
...@@ -3,10 +3,12 @@ from PIL import Image ...@@ -3,10 +3,12 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
from .cifar import CIFAR10
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract
class STL10(CIFAR10):
class STL10(VisionDataset):
"""`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset. """`STL10 <https://cs.stanford.edu/~acoates/stl10/>`_ Dataset.
Args: Args:
...@@ -46,7 +48,7 @@ class STL10(CIFAR10): ...@@ -46,7 +48,7 @@ class STL10(CIFAR10):
raise ValueError('Split "{}" not found. Valid splits are: {}'.format( raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits), split, ', '.join(self.splits),
)) ))
self.root = os.path.expanduser(root) super(STL10, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.split = split # train/test/unlabeled set self.split = split # train/test/unlabeled set
...@@ -129,5 +131,20 @@ class STL10(CIFAR10): ...@@ -129,5 +131,20 @@ class STL10(CIFAR10):
return images, labels return images, labels
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract(self.url, self.root, self.filename, self.tgz_md5)
def extra_repr(self): def extra_repr(self):
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
...@@ -37,7 +37,7 @@ class VisionDataset(data.Dataset): ...@@ -37,7 +37,7 @@ class VisionDataset(data.Dataset):
if self.root is not None: if self.root is not None:
body.append("Root location: {}".format(self.root)) body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines() body += self.extra_repr().splitlines()
if self.transforms is not None: if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)] body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body] lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines) return '\n'.join(lines)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment