"vscode:/vscode.git/clone" did not exist on "ca4a46eff11adc0881351db3e45378d23b521b92"
Commit 9958ecbe authored by Anuvabh Dutt's avatar Anuvabh Dutt Committed by GitHub
Browse files

Update svhn.py to be consistent with other datasets

Make the labels be 1d tensors of type `np.int64`.
Assign label `0`  to data samples of digit 0. The original dataset assigns them label `10`.
parent 08b1f59f
...@@ -66,7 +66,16 @@ class SVHN(data.Dataset): ...@@ -66,7 +66,16 @@ class SVHN(data.Dataset):
loaded_mat = sio.loadmat(os.path.join(root, self.filename)) loaded_mat = sio.loadmat(os.path.join(root, self.filename))
self.data = loaded_mat['X'] self.data = loaded_mat['X']
self.labels = loaded_mat['y'] # loading from the .mat file gives an np array of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
self.labels = loaded_mat['y'].astype(np.int64).squeeze()
# the svhn dataset assigns the class label "10" to the digit 0
# this makes it inconsistent with several loss functions
# which expect the class labels to be in the range [0, C-1]
np.place(self.labels, self.labels==10, 0)
self.data = np.transpose(self.data, (3, 2, 0, 1)) self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index): def __getitem__(self, index):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment