Commit 918fdffd authored by Alykhan Tejani's avatar Alykhan Tejani Committed by GitHub
Browse files

Merge pull request #194 from vabh/master

Update svhn.py to be consistent with other datasets
parents 5c094092 9ff1b4e8
...@@ -66,7 +66,16 @@ class SVHN(data.Dataset): ...@@ -66,7 +66,16 @@ class SVHN(data.Dataset):
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
self.data = loaded_mat['X'] self.data = loaded_mat['X']
self.labels = loaded_mat['y'] # loading from the .mat file gives an np array of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
self.labels = loaded_mat['y'].astype(np.int64).squeeze()
# the svhn dataset assigns the class label "10" to the digit 0
# this makes it inconsistent with several loss functions
# which expect the class labels to be in the range [0, C-1]
np.place(self.labels, self.labels == 10, 0)
self.data = np.transpose(self.data, (3, 2, 0, 1)) self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index): def __getitem__(self, index):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment