OpenDAS / vision · Commits · bbaa1b0d

Commit bbaa1b0d, authored Apr 12, 2019 by Philip Meier, committed by Francisco Massa, Apr 12, 2019

added support for VisionDataset (#838)

Parent: 8759f303

Showing 4 changed files with 24 additions and 69 deletions (+24 / -69):
torchvision/datasets/caltech.py   +8  -33
torchvision/datasets/celeba.py    +6  -14
torchvision/datasets/imagenet.py  +2  -19
torchvision/datasets/sbd.py       +8  -3
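The pattern is the same in all four files: each dataset's constructor now calls super().__init__(root) instead of setting self.root by hand, and the hand-rolled __repr__ methods are replaced by a short extra_repr() hook. The base class itself (torchvision/datasets/vision.py) is not touched by this commit, so the sketch below is only a rough reconstruction of where that hook's output would end up; it is modeled on the __repr__ logic removed from imagenet.py further down, and the class name VisionDatasetSketch is purely illustrative.

# A minimal sketch, not the real torchvision.datasets.vision.VisionDataset:
# that file is not part of this diff, so this only mirrors the __repr__ logic
# removed from imagenet.py below to show where extra_repr() output ends up.
import os

import torch.utils.data as data


class VisionDatasetSketch(data.Dataset):
    def __init__(self, root):
        self.root = os.path.expanduser(root)

    def __repr__(self):
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        # Dataset-specific lines come from the subclass hook.
        body += self.extra_repr().splitlines()
        lines = [head] + [" " * 4 + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self):
        return ""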
torchvision/datasets/caltech.py

@@ -2,19 +2,12 @@ from __future__ import print_function
 from PIL import Image
 import os
 import os.path
-import numpy as np
-import sys
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
-import collections
-import torch.utils.data as data

-from .utils import download_url, check_integrity, makedir_exist_ok
+from .vision import VisionDataset
+from .utils import download_url, makedir_exist_ok


-class Caltech101(data.Dataset):
+class Caltech101(VisionDataset):
     """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

     Args:
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
     def __init__(self, root, target_type="category",
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech101")
+        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
         makedir_exist_ok(self.root)
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -138,19 +131,11 @@ class Caltech101(data.Dataset):
         with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
             tar.extractall(path=self.root)

-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Target type: {target_type}".format(**self.__dict__)


-class Caltech256(data.Dataset):
+class Caltech256(VisionDataset):
     """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

     Args:
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
     def __init__(self, root,
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech256")
+        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
         makedir_exist_ok(self.root)
         self.transform = transform
         self.target_transform = target_transform
@@ -233,13 +218,3 @@ class Caltech256(data.Dataset):
         # extract file
         with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
             tar.extractall(path=self.root)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
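Caltech101's replacement method is a one-liner built on str.format with the instance dictionary. A tiny standalone demo of that idiom, using an illustrative class and value rather than the real dataset:

class Demo(object):
    def __init__(self, target_type):
        self.target_type = target_type

    def extra_repr(self):
        # Named fields are filled from instance attributes via __dict__.
        return "Target type: {target_type}".format(**self.__dict__)


print(Demo("category").extra_repr())
# Target type: category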
torchvision/datasets/celeba.py

 import torch
-import torch.utils.data as data
 import os
 import PIL

+from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity


-class CelebA(data.Dataset):
+class CelebA(VisionDataset):
     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

     Args:
@@ -53,7 +53,7 @@ class CelebA(data.Dataset):
                  transform=None, target_transform=None,
                  download=False):
         import pandas
-        self.root = os.path.expanduser(root)
+        super(CelebA, self).__init__(root)
         self.split = split
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -158,14 +158,6 @@ class CelebA(data.Dataset):
     def __len__(self):
         return len(self.attr)

-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return '\n'.join(lines).format(**self.__dict__)
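CelebA's version joins two template lines before formatting. Since target_type may be stored as a list (see the isinstance check kept above), the rendered line simply shows the list's repr; the values below are illustrative stand-ins for the instance's __dict__:

attrs = {"target_type": ["attr", "identity"], "split": "train"}
lines = ["Target type: {target_type}", "Split: {split}"]
print('\n'.join(lines).format(**attrs))
# Target type: ['attr', 'identity']
# Split: train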
torchvision/datasets/imagenet.py

@@ -132,25 +132,8 @@ class ImageNet(ImageFolder):
     def split_folder(self):
         return os.path.join(self.root, self.split)

-    def __repr__(self):
-        head = "Dataset " + self.__class__.__name__
-        body = ["Number of datapoints: {}".format(self.__len__())]
-        if self.root is not None:
-            body.append("Root location: {}".format(self.root))
-        body += ["Split: {}".format(self.split)]
-        if hasattr(self, 'transform') and self.transform is not None:
-            body += self._format_transform_repr(self.transform, "Transforms: ")
-        if hasattr(self, 'target_transform') and self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transforms: ")
-        lines = [head] + [" " * 4 + line for line in body]
-        return '\n'.join(lines)
-
-    def _format_transform_repr(self, transform, head):
-        lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+    def extra_repr(self):
+        return "Split: {split}".format(**self.__dict__)


 def extract_tar(src, dest=None, gzip=None, delete=False):
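The _format_transform_repr helper deleted above aligned multi-line transform reprs under their label; presumably equivalent formatting now lives in the shared base class, though that file is outside this diff. A standalone demo of what the removed helper did, using FakeCompose as a made-up stand-in for torchvision.transforms.Compose:

def _format_transform_repr(transform, head):
    # Put the label in front of the first line, indent the rest to match.
    lines = repr(transform).splitlines()
    return (["{}{}".format(head, lines[0])] +
            ["{}{}".format(" " * len(head), line) for line in lines[1:]])


class FakeCompose(object):
    # Stand-in whose repr spans several lines, like transforms.Compose.
    def __repr__(self):
        return "Compose(\nResize(256)\nCenterCrop(224)\n)"


print('\n'.join(_format_transform_repr(FakeCompose(), "Transforms: ")))
# Transforms: Compose(
#             Resize(256)
#             CenterCrop(224)
#             )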
torchvision/datasets/sbd.py

 import os
-import torch.utils.data as data
+from .vision import VisionDataset
 import numpy as np

@@ -8,7 +8,7 @@ from .utils import download_url
 from .voc import download_extract


-class SBDataset(data.Dataset):
+class SBDataset(VisionDataset):
     """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_

     The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
@@ -62,10 +62,11 @@ class SBDataset(data.Dataset):
             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                                "pip install scipy")

+        super(SBDataset, self).__init__(root)
+
         if mode not in ("segmentation", "boundaries"):
             raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")

-        self.root = os.path.expanduser(root)
         self.xy_transform = xy_transform
         self.image_set = image_set
         self.mode = mode
@@ -121,3 +122,7 @@ class SBDataset(data.Dataset):
     def __len__(self):
         return len(self.images)
+
+    def extra_repr(self):
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return '\n'.join(lines).format(**self.__dict__)
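Taken together, the four files show the contract a dataset follows after this commit. A hedged sketch of how a new dataset would adopt it; MyToyDataset and its attributes are made up, and only the super().__init__(root) call and the extra_repr() hook are taken from the diff above (torchvision.datasets.vision is the module referenced by the from .vision imports):

import os

from torchvision.datasets.vision import VisionDataset


class MyToyDataset(VisionDataset):
    def __init__(self, root, split="train"):
        # The base class takes the root directory, as in the diffs above.
        super(MyToyDataset, self).__init__(os.path.join(root, "mytoy"))
        self.split = split
        self.samples = []  # would normally be discovered under self.root

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

    def extra_repr(self):
        return "Split: {split}".format(**self.__dict__)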