Unverified Commit ced96a0c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix STL10 folds (#3353)

* fix STL10 folds

* use np.int64 over np.long
parent 6949b893
...@@ -181,7 +181,7 @@ class STL10(VisionDataset): ...@@ -181,7 +181,7 @@ class STL10(VisionDataset):
self.root, self.base_folder, self.folds_list_file) self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f: with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds] str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ') list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ')
self.data = self.data[list_idx, :, :, :] self.data = self.data[list_idx, :, :, :]
if self.labels is not None: if self.labels is not None:
self.labels = self.labels[list_idx] self.labels = self.labels[list_idx]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment