Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
bbaa1b0d
Commit
bbaa1b0d
authored
Apr 12, 2019
by
Philip Meier
Committed by
Francisco Massa
Apr 12, 2019
Browse files
added support for VisionDataset (#838)
parent
8759f303
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
69 deletions
+24
-69
torchvision/datasets/caltech.py
torchvision/datasets/caltech.py
+8
-33
torchvision/datasets/celeba.py
torchvision/datasets/celeba.py
+6
-14
torchvision/datasets/imagenet.py
torchvision/datasets/imagenet.py
+2
-19
torchvision/datasets/sbd.py
torchvision/datasets/sbd.py
+8
-3
No files found.
torchvision/datasets/caltech.py
View file @
bbaa1b0d
...
...
@@ -2,19 +2,12 @@ from __future__ import print_function
from
PIL
import
Image
import
os
import
os.path
import
numpy
as
np
import
sys
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
import
collections
import
torch.utils.data
as
data
from
.utils
import
download_url
,
check_integrity
,
makedir_exist_ok
from
.vision
import
VisionDataset
from
.utils
import
download_url
,
makedir_exist_ok
class
Caltech101
(
data
.
Dataset
):
class
Caltech101
(
Vision
Dataset
):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
Args:
...
...
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
def
__init__
(
self
,
root
,
target_type
=
"category"
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
s
elf
.
root
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
root
)
,
"
caltech101
"
)
s
uper
(
Caltech101
,
self
).
__init__
(
os
.
path
.
join
(
root
,
'
caltech101
'
)
)
makedir_exist_ok
(
self
.
root
)
if
isinstance
(
target_type
,
list
):
self
.
target_type
=
target_type
...
...
@@ -138,19 +131,11 @@ class Caltech101(data.Dataset):
with
tarfile
.
open
(
os
.
path
.
join
(
self
.
root
,
"101_Annotations.tar"
),
"r:"
)
as
tar
:
tar
.
extractall
(
path
=
self
.
root
)
def
__repr__
(
self
):
fmt_str
=
'Dataset '
+
self
.
__class__
.
__name__
+
'
\n
'
fmt_str
+=
' Number of datapoints: {}
\n
'
.
format
(
self
.
__len__
())
fmt_str
+=
' Target type: {}
\n
'
.
format
(
self
.
target_type
)
fmt_str
+=
' Root Location: {}
\n
'
.
format
(
self
.
root
)
tmp
=
' Transforms (if any): '
fmt_str
+=
'{0}{1}
\n
'
.
format
(
tmp
,
self
.
transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
tmp
=
' Target Transforms (if any): '
fmt_str
+=
'{0}{1}'
.
format
(
tmp
,
self
.
target_transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
return
fmt_str
def
extra_repr
(
self
):
return
"Target type: {target_type}"
.
format
(
**
self
.
__dict__
)
class
Caltech256
(
data
.
Dataset
):
class
Caltech256
(
Vision
Dataset
):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
Args:
...
...
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
def
__init__
(
self
,
root
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
s
elf
.
root
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
root
)
,
"
caltech256
"
)
s
uper
(
Caltech256
,
self
).
__init__
(
os
.
path
.
join
(
root
,
'
caltech256
'
)
)
makedir_exist_ok
(
self
.
root
)
self
.
transform
=
transform
self
.
target_transform
=
target_transform
...
...
@@ -233,13 +218,3 @@ class Caltech256(data.Dataset):
# extract file
with
tarfile
.
open
(
os
.
path
.
join
(
self
.
root
,
"256_ObjectCategories.tar"
),
"r:"
)
as
tar
:
tar
.
extractall
(
path
=
self
.
root
)
def
__repr__
(
self
):
fmt_str
=
'Dataset '
+
self
.
__class__
.
__name__
+
'
\n
'
fmt_str
+=
' Number of datapoints: {}
\n
'
.
format
(
self
.
__len__
())
fmt_str
+=
' Root Location: {}
\n
'
.
format
(
self
.
root
)
tmp
=
' Transforms (if any): '
fmt_str
+=
'{0}{1}
\n
'
.
format
(
tmp
,
self
.
transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
tmp
=
' Target Transforms (if any): '
fmt_str
+=
'{0}{1}'
.
format
(
tmp
,
self
.
target_transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
return
fmt_str
torchvision/datasets/celeba.py
View file @
bbaa1b0d
import
torch
import
torch.utils.data
as
data
import
os
import
PIL
from
.vision
import
VisionDataset
from
.utils
import
download_file_from_google_drive
,
check_integrity
class
CelebA
(
data
.
Dataset
):
class
CelebA
(
Vision
Dataset
):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
...
...
@@ -53,7 +53,7 @@ class CelebA(data.Dataset):
transform
=
None
,
target_transform
=
None
,
download
=
False
):
import
pandas
s
elf
.
root
=
os
.
path
.
expanduser
(
root
)
s
uper
(
CelebA
,
self
).
__init__
(
root
)
self
.
split
=
split
if
isinstance
(
target_type
,
list
):
self
.
target_type
=
target_type
...
...
@@ -158,14 +158,6 @@ class CelebA(data.Dataset):
def
__len__
(
self
):
return
len
(
self
.
attr
)
def
__repr__
(
self
):
fmt_str
=
'Dataset '
+
self
.
__class__
.
__name__
+
'
\n
'
fmt_str
+=
' Number of datapoints: {}
\n
'
.
format
(
self
.
__len__
())
fmt_str
+=
' Target type: {}
\n
'
.
format
(
self
.
target_type
)
fmt_str
+=
' Split: {}
\n
'
.
format
(
self
.
split
)
fmt_str
+=
' Root Location: {}
\n
'
.
format
(
self
.
root
)
tmp
=
' Transforms (if any): '
fmt_str
+=
'{0}{1}
\n
'
.
format
(
tmp
,
self
.
transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
tmp
=
' Target Transforms (if any): '
fmt_str
+=
'{0}{1}'
.
format
(
tmp
,
self
.
target_transform
.
__repr__
().
replace
(
'
\n
'
,
'
\n
'
+
' '
*
len
(
tmp
)))
return
fmt_str
def
extra_repr
(
self
):
lines
=
[
"Target type: {target_type}"
,
"Split: {split}"
]
return
'
\n
'
.
join
(
lines
).
format
(
**
self
.
__dict__
)
torchvision/datasets/imagenet.py
View file @
bbaa1b0d
...
...
@@ -132,25 +132,8 @@ class ImageNet(ImageFolder):
def
split_folder
(
self
):
return
os
.
path
.
join
(
self
.
root
,
self
.
split
)
def
__repr__
(
self
):
head
=
"Dataset "
+
self
.
__class__
.
__name__
body
=
[
"Number of datapoints: {}"
.
format
(
self
.
__len__
())]
if
self
.
root
is
not
None
:
body
.
append
(
"Root location: {}"
.
format
(
self
.
root
))
body
+=
[
"Split: {}"
.
format
(
self
.
split
)]
if
hasattr
(
self
,
'transform'
)
and
self
.
transform
is
not
None
:
body
+=
self
.
_format_transform_repr
(
self
.
transform
,
"Transforms: "
)
if
hasattr
(
self
,
'target_transform'
)
and
self
.
target_transform
is
not
None
:
body
+=
self
.
_format_transform_repr
(
self
.
target_transform
,
"Target transforms: "
)
lines
=
[
head
]
+
[
" "
*
4
+
line
for
line
in
body
]
return
'
\n
'
.
join
(
lines
)
def
_format_transform_repr
(
self
,
transform
,
head
):
lines
=
transform
.
__repr__
().
splitlines
()
return
([
"{}{}"
.
format
(
head
,
lines
[
0
])]
+
[
"{}{}"
.
format
(
" "
*
len
(
head
),
line
)
for
line
in
lines
[
1
:]])
def
extra_repr
(
self
):
return
"Split: {split}"
.
format
(
**
self
.
__dict__
)
def
extract_tar
(
src
,
dest
=
None
,
gzip
=
None
,
delete
=
False
):
...
...
torchvision/datasets/sbd.py
View file @
bbaa1b0d
import
os
import
torch.utils.data
as
data
from
.vision
import
VisionDataset
import
numpy
as
np
...
...
@@ -8,7 +8,7 @@ from .utils import download_url
from
.voc
import
download_extract
class
SBDataset
(
data
.
Dataset
):
class
SBDataset
(
Vision
Dataset
):
"""`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
...
...
@@ -62,10 +62,11 @@ class SBDataset(data.Dataset):
raise
RuntimeError
(
"Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy"
)
super
(
SBDataset
,
self
).
__init__
(
root
)
if
mode
not
in
(
"segmentation"
,
"boundaries"
):
raise
ValueError
(
"Argument mode should be 'segmentation' or 'boundaries'"
)
self
.
root
=
os
.
path
.
expanduser
(
root
)
self
.
xy_transform
=
xy_transform
self
.
image_set
=
image_set
self
.
mode
=
mode
...
...
@@ -121,3 +122,7 @@ class SBDataset(data.Dataset):
def
__len__
(
self
):
return
len
(
self
.
images
)
def
extra_repr
(
self
):
lines
=
[
"Image set: {image_set}"
,
"Mode: {mode}"
]
return
'
\n
'
.
join
(
lines
).
format
(
**
self
.
__dict__
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment