Commit 3b6c6706 authored by reynoldscem's avatar reynoldscem Committed by Soumith Chintala
Browse files

Fix for issue #447 - STL dataset returns test fold if fold is misspecified (#449)

parent c76ac7ff
...@@ -41,9 +41,14 @@ class STL10(CIFAR10): ...@@ -41,9 +41,14 @@ class STL10(CIFAR10):
['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
] ]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', def __init__(self, root, split='train',
transform=None, target_transform=None, download=False): transform=None, target_transform=None, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment