Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
918fdffd
Commit
918fdffd
authored
Sep 20, 2017
by
Alykhan Tejani
Committed by
GitHub
Sep 20, 2017
Browse files
Merge pull request #194 from vabh/master
Update svhn.py to be consistent with other datasets
parents
5c094092
9ff1b4e8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
1 deletion
+10
-1
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+10
-1
No files found.
torchvision/datasets/svhn.py
View file @
918fdffd
...
@@ -66,7 +66,16 @@ class SVHN(data.Dataset):
...
@@ -66,7 +66,16 @@ class SVHN(data.Dataset):
loaded_mat
=
sio
.
loadmat
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
))
loaded_mat
=
sio
.
loadmat
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
))
self
.
data
=
loaded_mat
[
'X'
]
self
.
data
=
loaded_mat
[
'X'
]
self
.
labels
=
loaded_mat
[
'y'
]
# loading from the .mat file gives an np array of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
self
.
labels
=
loaded_mat
[
'y'
].
astype
(
np
.
int64
).
squeeze
()
# the svhn dataset assigns the class label "10" to the digit 0
# this makes it inconsistent with several loss functions
# which expect the class labels to be in the range [0, C-1]
np
.
place
(
self
.
labels
,
self
.
labels
==
10
,
0
)
self
.
data
=
np
.
transpose
(
self
.
data
,
(
3
,
2
,
0
,
1
))
self
.
data
=
np
.
transpose
(
self
.
data
,
(
3
,
2
,
0
,
1
))
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment