Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
00ce2d0f
Commit
00ce2d0f
authored
Mar 17, 2017
by
soumith
Committed by
Soumith Chintala
Mar 17, 2017
Browse files
refactor download and md5-checking utilities
parent
c7a39ba9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
29 deletions
+40
-29
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+11
-29
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+29
-0
No files found.
torchvision/datasets/cifar.py
View file @
00ce2d0f
...
@@ -11,6 +11,8 @@ if sys.version_info[0] == 2:
...
@@ -11,6 +11,8 @@ if sys.version_info[0] == 2:
else
:
else
:
import
pickle
import
pickle
import
.utils
as
utils
class
CIFAR10
(
data
.
Dataset
):
class
CIFAR10
(
data
.
Dataset
):
base_folder
=
'cifar-10-batches-py'
base_folder
=
'cifar-10-batches-py'
...
@@ -29,7 +31,9 @@ class CIFAR10(data.Dataset):
...
@@ -29,7 +31,9 @@ class CIFAR10(data.Dataset):
[
'test_batch'
,
'40351d587109b95175f43aff81a1287e'
],
[
'test_batch'
,
'40351d587109b95175f43aff81a1287e'
],
]
]
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
root
self
.
root
=
root
self
.
transform
=
transform
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
target_transform
=
target_transform
...
@@ -106,55 +110,33 @@ class CIFAR10(data.Dataset):
...
@@ -106,55 +110,33 @@ class CIFAR10(data.Dataset):
return
10000
return
10000
def
_check_integrity
(
self
):
def
_check_integrity
(
self
):
import
hashlib
root
=
self
.
root
root
=
self
.
root
for
fentry
in
(
self
.
train_list
+
self
.
test_list
):
for
fentry
in
(
self
.
train_list
+
self
.
test_list
):
filename
,
md5
=
fentry
[
0
],
fentry
[
1
]
filename
,
md5
=
fentry
[
0
],
fentry
[
1
]
fpath
=
os
.
path
.
join
(
root
,
self
.
base_folder
,
filename
)
fpath
=
os
.
path
.
join
(
root
,
self
.
base_folder
,
filename
)
if
not
os
.
path
.
isfile
(
fpath
):
if
not
utils
.
check_integrity
(
fpath
,
md5
):
return
False
md5c
=
hashlib
.
md5
(
open
(
fpath
,
'rb'
).
read
()).
hexdigest
()
if
md5c
!=
md5
:
return
False
return
False
return
True
return
True
def
download
(
self
):
def
download
(
self
):
from
six.moves
import
urllib
import
tarfile
import
tarfile
import
hashlib
root
=
self
.
root
fpath
=
os
.
path
.
join
(
root
,
self
.
filename
)
try
:
os
.
makedirs
(
root
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
:
pass
else
:
raise
if
self
.
_check_integrity
():
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
print
(
'Files already downloaded and verified'
)
return
return
# downloads file
root
=
self
.
root
if
os
.
path
.
isfile
(
fpath
)
and
\
hashlib
.
md5
(
open
(
fpath
,
'rb'
).
read
()).
hexdigest
()
==
self
.
tgz_md5
:
# download
print
(
'Using downloaded file: '
+
fpath
)
utils
.
download
(
self
.
url
,
root
,
self
.
filename
,
self
.
tgz_md5
)
else
:
print
(
'Downloading '
+
self
.
url
+
' to '
+
fpath
)
urllib
.
request
.
urlretrieve
(
self
.
url
,
fpath
)
# extract file
# extract file
cwd
=
os
.
getcwd
()
cwd
=
os
.
getcwd
()
print
(
'Extracting tar file'
)
tar
=
tarfile
.
open
(
os
.
path
.
join
(
root
,
self
.
filename
),
"r:gz"
)
tar
=
tarfile
.
open
(
fpath
,
"r:gz"
)
os
.
chdir
(
root
)
os
.
chdir
(
root
)
tar
.
extractall
()
tar
.
extractall
()
tar
.
close
()
tar
.
close
()
os
.
chdir
(
cwd
)
os
.
chdir
(
cwd
)
print
(
'Done!'
)
class
CIFAR100
(
CIFAR10
):
class
CIFAR100
(
CIFAR10
):
...
...
torchvision/datasets/utils.py
0 → 100644
View file @
00ce2d0f
def
check_integrity
(
fpath
,
md5
):
import
hashlib
if
not
os
.
path
.
isfile
(
fpath
):
return
False
md5c
=
hashlib
.
md5
(
open
(
fpath
,
'rb'
).
read
()).
hexdigest
()
if
md5c
!=
md5
:
return
False
return
True
def
download
(
url
,
root
,
filename
,
md5
=
None
):
from
six.moves
import
urllib
fpath
=
os
.
path
.
join
(
root
,
filename
)
try
:
os
.
makedirs
(
root
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
:
pass
else
:
raise
# downloads file
if
os
.
path
.
isfile
(
fpath
)
and
check_integrity
(
fpath
,
md5
):
print
(
'Using downloaded and verified file: '
+
fpath
)
else
:
print
(
'Downloading '
+
url
+
' to '
+
fpath
)
urllib
.
request
.
urlretrieve
(
url
,
fpath
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment