Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b13bac51
Commit
b13bac51
authored
Mar 17, 2017
by
soumith
Committed by
Soumith Chintala
Mar 17, 2017
Browse files
refactored phototour to use utils
parent
32460f52
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
52 additions
and
61 deletions
+52
-61
setup.cfg
setup.cfg
+1
-1
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+1
-0
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+1
-1
torchvision/datasets/lsun.py
torchvision/datasets/lsun.py
+0
-12
torchvision/datasets/phototour.py
torchvision/datasets/phototour.py
+32
-43
torchvision/datasets/stl10.py
torchvision/datasets/stl10.py
+4
-2
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+1
-0
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+4
-2
torchvision/models/inception.py
torchvision/models/inception.py
+8
-0
No files found.
setup.cfg
View file @
b13bac51
...
@@ -6,5 +6,5 @@ max-line-length = 120
...
@@ -6,5 +6,5 @@ max-line-length = 120
[flake8]
[flake8]
max-line-length = 120
max-line-length = 120
ignore = F401,F403
ignore = F401,
E402,
F403
exclude = venv
exclude = venv
torchvision/datasets/__init__.py
View file @
b13bac51
...
@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
...
@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from
.stl10
import
STL10
from
.stl10
import
STL10
from
.mnist
import
MNIST
from
.mnist
import
MNIST
from
.svhn
import
SVHN
from
.svhn
import
SVHN
from
.phototour
import
PhotoTour
__all__
=
(
'LSUN'
,
'LSUNClass'
,
__all__
=
(
'LSUN'
,
'LSUNClass'
,
'ImageFolder'
,
'ImageFolder'
,
...
...
torchvision/datasets/cifar.py
View file @
b13bac51
from
__future__
import
print_function
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
from
PIL
import
Image
import
os
import
os
import
os.path
import
os.path
...
@@ -11,6 +10,7 @@ if sys.version_info[0] == 2:
...
@@ -11,6 +10,7 @@ if sys.version_info[0] == 2:
else
:
else
:
import
pickle
import
pickle
import
torch.utils.data
as
data
from
.utils
import
download_url
,
check_integrity
from
.utils
import
download_url
,
check_integrity
...
...
torchvision/datasets/lsun.py
View file @
b13bac51
...
@@ -127,15 +127,3 @@ class LSUN(data.Dataset):
...
@@ -127,15 +127,3 @@ class LSUN(data.Dataset):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
' ('
+
self
.
db_path
+
')'
return
self
.
__class__
.
__name__
+
' ('
+
self
.
db_path
+
')'
if
__name__
==
'__main__'
:
# lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
# a = lsun[0]
lsun
=
LSUN
(
db_path
=
'/home/soumith/local/lsun/train'
,
classes
=
[
'bedroom_train'
,
'church_outdoor_train'
])
print
(
lsun
.
classes
)
print
(
lsun
.
dbs
)
a
,
t
=
lsun
[
len
(
lsun
)
-
1
]
print
(
a
)
print
(
t
)
torchvision/datasets/phototour.py
View file @
b13bac51
...
@@ -6,12 +6,26 @@ from PIL import Image
...
@@ -6,12 +6,26 @@ from PIL import Image
import
torch
import
torch
import
torch.utils.data
as
data
import
torch.utils.data
as
data
from
.utils
import
download_url
,
check_integrity
class
PhotoTour
(
data
.
Dataset
):
class
PhotoTour
(
data
.
Dataset
):
urls
=
{
urls
=
{
'notredame'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip'
,
'notredame'
:
[
'yosemite'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip'
,
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip'
,
'liberty'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip'
'notredame.zip'
,
'509eda8535847b8c0a90bbb210c83484'
],
'yosemite'
:
[
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip'
,
'yosemite.zip'
,
'533b2e8eb7ede31be40abc317b2fd4f0'
],
'liberty'
:
[
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip'
,
'liberty.zip'
,
'fdd9152f138ea5ef2091746689176414'
],
}
}
mean
=
{
'notredame'
:
0.4854
,
'yosemite'
:
0.4844
,
'liberty'
:
0.4437
}
mean
=
{
'notredame'
:
0.4854
,
'yosemite'
:
0.4844
,
'liberty'
:
0.4437
}
std
=
{
'notredame'
:
0.1864
,
'yosemite'
:
0.1818
,
'liberty'
:
0.2019
}
std
=
{
'notredame'
:
0.1864
,
'yosemite'
:
0.1818
,
'liberty'
:
0.2019
}
...
@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset):
...
@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset):
if
download
:
if
download
:
self
.
download
()
self
.
download
()
if
not
self
.
_check_exists
():
if
not
self
.
_check_
datafile_
exists
():
raise
RuntimeError
(
'Dataset not found.'
+
raise
RuntimeError
(
'Dataset not found.'
+
' You can use download=True to download it'
)
' You can use download=True to download it'
)
...
@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset):
...
@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset):
return
self
.
lens
[
self
.
name
]
return
self
.
lens
[
self
.
name
]
return
len
(
self
.
matches
)
return
len
(
self
.
matches
)
def
_check_exists
(
self
):
def
_check_
datafile_
exists
(
self
):
return
os
.
path
.
exists
(
self
.
data_file
)
return
os
.
path
.
exists
(
self
.
data_file
)
def
_check_downloaded
(
self
):
def
_check_downloaded
(
self
):
return
os
.
path
.
exists
(
self
.
data_dir
)
return
os
.
path
.
exists
(
self
.
data_dir
)
def
download
(
self
):
def
download
(
self
):
from
six.moves
import
urllib
if
self
.
_check_datafile_exists
():
print
(
'
\n
-- Loading PhotoTour dataset: {}
\n
'
.
format
(
self
.
name
))
if
self
.
_check_exists
():
print
(
'# Found cached data {}'
.
format
(
self
.
data_file
))
print
(
'# Found cached data {}'
.
format
(
self
.
data_file
))
return
return
if
not
self
.
_check_downloaded
():
if
not
self
.
_check_downloaded
():
# download files
# download files
url
=
self
.
urls
[
self
.
name
]
url
=
self
.
urls
[
self
.
name
][
0
]
filename
=
url
.
rpartition
(
'/'
)[
2
]
filename
=
self
.
urls
[
self
.
name
][
1
]
file_path
=
os
.
path
.
join
(
self
.
root
,
filename
)
md5
=
self
.
urls
[
self
.
name
][
2
]
fpath
=
os
.
path
.
join
(
self
.
root
,
filename
)
try
:
os
.
makedirs
(
self
.
root
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
:
pass
else
:
raise
print
(
'# Downloading {} into {}
\n\n
It might take while.'
download_url
(
url
,
self
.
root
,
filename
,
md5
)
' Please grab yourself a coffee and relax.'
.
format
(
url
,
file_path
))
urllib
.
request
.
urlretrieve
(
url
,
file_path
)
assert
os
.
path
.
exists
(
file_path
)
print
(
'# Extracting data {}
\n
'
.
format
(
self
.
data_down
))
print
(
'# Extracting data {}
\n
'
.
format
(
self
.
data_down
))
import
zipfile
import
zipfile
with
zipfile
.
ZipFile
(
f
ile_
path
,
'r'
)
as
z
:
with
zipfile
.
ZipFile
(
fpath
,
'r'
)
as
z
:
z
.
extractall
(
self
.
data_dir
)
z
.
extractall
(
self
.
data_dir
)
os
.
unlink
(
file_path
)
os
.
unlink
(
fpath
)
# process and save as torch files
# process and save as torch files
print
(
'# Caching data {}'
.
format
(
self
.
data_file
))
print
(
'# Caching data {}'
.
format
(
self
.
data_file
))
data
_
set
=
(
dataset
=
(
read_image_file
(
self
.
data_dir
,
self
.
image_ext
,
self
.
lens
[
self
.
name
]),
read_image_file
(
self
.
data_dir
,
self
.
image_ext
,
self
.
lens
[
self
.
name
]),
read_info_file
(
self
.
data_dir
,
self
.
info_file
),
read_info_file
(
self
.
data_dir
,
self
.
info_file
),
read_matches_files
(
self
.
data_dir
,
self
.
matches_files
)
read_matches_files
(
self
.
data_dir
,
self
.
matches_files
)
)
)
with
open
(
self
.
data_file
,
'wb'
)
as
f
:
with
open
(
self
.
data_file
,
'wb'
)
as
f
:
torch
.
save
(
data
_
set
,
f
)
torch
.
save
(
dataset
,
f
)
def
read_image_file
(
data_dir
,
image_ext
,
n
):
def
read_image_file
(
data_dir
,
image_ext
,
n
):
...
@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n):
...
@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n):
patches
=
[]
patches
=
[]
list_files
=
find_files
(
data_dir
,
image_ext
)
list_files
=
find_files
(
data_dir
,
image_ext
)
for
f
ile_
path
in
list_files
:
for
fpath
in
list_files
:
img
=
Image
.
open
(
f
ile_
path
)
img
=
Image
.
open
(
fpath
)
for
y
in
range
(
0
,
1024
,
64
):
for
y
in
range
(
0
,
1024
,
64
):
for
x
in
range
(
0
,
1024
,
64
):
for
x
in
range
(
0
,
1024
,
64
):
patch
=
img
.
crop
((
x
,
y
,
x
+
64
,
y
+
64
))
patch
=
img
.
crop
((
x
,
y
,
x
+
64
,
y
+
64
))
...
@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file):
...
@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file):
l
=
line
.
split
()
l
=
line
.
split
()
matches
.
append
([
int
(
l
[
0
]),
int
(
l
[
3
]),
int
(
l
[
1
]
==
l
[
4
])])
matches
.
append
([
int
(
l
[
0
]),
int
(
l
[
3
]),
int
(
l
[
1
]
==
l
[
4
])])
return
torch
.
LongTensor
(
matches
)
return
torch
.
LongTensor
(
matches
)
if
__name__
==
'__main__'
:
dataset
=
PhotoTour
(
root
=
'/home/eriba/datasets/patches_dataset'
,
name
=
'notredame'
,
download
=
True
)
print
(
'Loaded PhotoTour: {} with {} images.'
.
format
(
dataset
.
name
,
len
(
dataset
.
data
)))
assert
len
(
dataset
.
data
)
==
len
(
dataset
.
labels
)
torchvision/datasets/stl10.py
View file @
b13bac51
...
@@ -26,7 +26,8 @@ class STL10(CIFAR10):
...
@@ -26,7 +26,8 @@ class STL10(CIFAR10):
[
'test_y.bin'
,
'36f9794fa4beb8a2c72628de14fa638e'
]
[
'test_y.bin'
,
'36f9794fa4beb8a2c72628de14fa638e'
]
]
]
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
root
self
.
root
=
root
self
.
transform
=
transform
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
target_transform
=
target_transform
...
@@ -37,7 +38,8 @@ class STL10(CIFAR10):
...
@@ -37,7 +38,8 @@ class STL10(CIFAR10):
if
not
self
.
_check_integrity
():
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
raise
RuntimeError
(
'Dataset not found or corrupted. You can use download=True to download it'
)
'Dataset not found or corrupted. '
'You can use download=True to download it'
)
# now load the picked numpy arrays
# now load the picked numpy arrays
if
self
.
split
==
'train'
:
if
self
.
split
==
'train'
:
...
...
torchvision/datasets/svhn.py
View file @
b13bac51
...
@@ -8,6 +8,7 @@ import numpy as np
...
@@ -8,6 +8,7 @@ import numpy as np
import
sys
import
sys
from
.utils
import
download_url
,
check_integrity
from
.utils
import
download_url
,
check_integrity
class
SVHN
(
data
.
Dataset
):
class
SVHN
(
data
.
Dataset
):
url
=
""
url
=
""
filename
=
""
filename
=
""
...
...
torchvision/datasets/utils.py
View file @
b13bac51
import
os
import
os
import
os.path
import
hashlib
import
hashlib
import
errno
import
errno
def
check_integrity
(
fpath
,
md5
):
def
check_integrity
(
fpath
,
md5
):
if
not
os
.
path
.
isfile
(
fpath
):
if
not
os
.
path
.
isfile
(
fpath
):
return
False
return
False
md5o
=
hashlib
.
md5
()
md5o
=
hashlib
.
md5
()
with
open
(
fpath
,
'rb'
)
as
f
:
with
open
(
fpath
,
'rb'
)
as
f
:
# read in 1MB chunks
# read in 1MB chunks
for
chunk
in
iter
(
lambda
:
f
.
read
(
1024
*
1024
*
1024
),
b
''
):
for
chunk
in
iter
(
lambda
:
f
.
read
(
1024
*
1024
*
1024
),
b
''
):
md5o
.
update
(
chunk
)
md5o
.
update
(
chunk
)
...
@@ -16,7 +18,7 @@ def check_integrity(fpath, md5):
...
@@ -16,7 +18,7 @@ def check_integrity(fpath, md5):
return
True
return
True
def
download_url
(
url
,
root
,
filename
,
md5
=
None
):
def
download_url
(
url
,
root
,
filename
,
md5
):
from
six.moves
import
urllib
from
six.moves
import
urllib
fpath
=
os
.
path
.
join
(
root
,
filename
)
fpath
=
os
.
path
.
join
(
root
,
filename
)
...
...
torchvision/models/inception.py
View file @
b13bac51
...
@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs):
...
@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs):
class
Inception3
(
nn
.
Module
):
class
Inception3
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
=
1000
,
aux_logits
=
True
,
transform_input
=
False
):
def
__init__
(
self
,
num_classes
=
1000
,
aux_logits
=
True
,
transform_input
=
False
):
super
(
Inception3
,
self
).
__init__
()
super
(
Inception3
,
self
).
__init__
()
self
.
aux_logits
=
aux_logits
self
.
aux_logits
=
aux_logits
...
@@ -126,6 +127,7 @@ class Inception3(nn.Module):
...
@@ -126,6 +127,7 @@ class Inception3(nn.Module):
class
InceptionA
(
nn
.
Module
):
class
InceptionA
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
pool_features
):
def
__init__
(
self
,
in_channels
,
pool_features
):
super
(
InceptionA
,
self
).
__init__
()
super
(
InceptionA
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
64
,
kernel_size
=
1
)
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
64
,
kernel_size
=
1
)
...
@@ -157,6 +159,7 @@ class InceptionA(nn.Module):
...
@@ -157,6 +159,7 @@ class InceptionA(nn.Module):
class
InceptionB
(
nn
.
Module
):
class
InceptionB
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionB
,
self
).
__init__
()
super
(
InceptionB
,
self
).
__init__
()
self
.
branch3x3
=
BasicConv2d
(
in_channels
,
384
,
kernel_size
=
3
,
stride
=
2
)
self
.
branch3x3
=
BasicConv2d
(
in_channels
,
384
,
kernel_size
=
3
,
stride
=
2
)
...
@@ -179,6 +182,7 @@ class InceptionB(nn.Module):
...
@@ -179,6 +182,7 @@ class InceptionB(nn.Module):
class
InceptionC
(
nn
.
Module
):
class
InceptionC
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
channels_7x7
):
def
__init__
(
self
,
in_channels
,
channels_7x7
):
super
(
InceptionC
,
self
).
__init__
()
super
(
InceptionC
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
...
@@ -217,6 +221,7 @@ class InceptionC(nn.Module):
...
@@ -217,6 +221,7 @@ class InceptionC(nn.Module):
class
InceptionD
(
nn
.
Module
):
class
InceptionD
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionD
,
self
).
__init__
()
super
(
InceptionD
,
self
).
__init__
()
self
.
branch3x3_1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
self
.
branch3x3_1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
...
@@ -242,6 +247,7 @@ class InceptionD(nn.Module):
...
@@ -242,6 +247,7 @@ class InceptionD(nn.Module):
class
InceptionE
(
nn
.
Module
):
class
InceptionE
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionE
,
self
).
__init__
()
super
(
InceptionE
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
320
,
kernel_size
=
1
)
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
320
,
kernel_size
=
1
)
...
@@ -283,6 +289,7 @@ class InceptionE(nn.Module):
...
@@ -283,6 +289,7 @@ class InceptionE(nn.Module):
class
InceptionAux
(
nn
.
Module
):
class
InceptionAux
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
num_classes
):
def
__init__
(
self
,
in_channels
,
num_classes
):
super
(
InceptionAux
,
self
).
__init__
()
super
(
InceptionAux
,
self
).
__init__
()
self
.
conv0
=
BasicConv2d
(
in_channels
,
128
,
kernel_size
=
1
)
self
.
conv0
=
BasicConv2d
(
in_channels
,
128
,
kernel_size
=
1
)
...
@@ -307,6 +314,7 @@ class InceptionAux(nn.Module):
...
@@ -307,6 +314,7 @@ class InceptionAux(nn.Module):
class
BasicConv2d
(
nn
.
Module
):
class
BasicConv2d
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
BasicConv2d
,
self
).
__init__
()
super
(
BasicConv2d
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
bias
=
False
,
**
kwargs
)
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
bias
=
False
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment