Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
31245cb8
Unverified
Commit
31245cb8
authored
Jul 31, 2020
by
Philip Meier
Committed by
GitHub
Jul 31, 2020
Browse files
caltech (#2521)
parent
e1c50d9c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
15 deletions
+28
-15
torchvision/datasets/caltech.py
torchvision/datasets/caltech.py
+28
-15
No files found.
torchvision/datasets/caltech.py
View file @
31245cb8
from
PIL
import
Image
import
os
import
os.path
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
,
Tuple
from
.vision
import
VisionDataset
from
.utils
import
download_and_extract_archive
,
verify_str_arg
...
...
@@ -29,8 +30,14 @@ class Caltech101(VisionDataset):
downloaded again.
"""
def
__init__
(
self
,
root
,
target_type
=
"category"
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
target_type
:
Union
[
List
[
str
],
str
]
=
"category"
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Caltech101
,
self
).
__init__
(
os
.
path
.
join
(
root
,
'caltech101'
),
transform
=
transform
,
target_transform
=
target_transform
)
...
...
@@ -59,14 +66,14 @@ class Caltech101(VisionDataset):
"airplanes"
:
"Airplanes_Side_2"
}
self
.
annotation_categories
=
list
(
map
(
lambda
x
:
name_map
[
x
]
if
x
in
name_map
else
x
,
self
.
categories
))
self
.
index
=
[]
self
.
index
:
List
[
int
]
=
[]
self
.
y
=
[]
for
(
i
,
c
)
in
enumerate
(
self
.
categories
):
n
=
len
(
os
.
listdir
(
os
.
path
.
join
(
self
.
root
,
"101_ObjectCategories"
,
c
)))
self
.
index
.
extend
(
range
(
1
,
n
+
1
))
self
.
y
.
extend
(
n
*
[
i
])
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -81,7 +88,7 @@ class Caltech101(VisionDataset):
self
.
categories
[
self
.
y
[
index
]],
"image_{:04d}.jpg"
.
format
(
self
.
index
[
index
])))
target
=
[]
target
:
Any
=
[]
for
t
in
self
.
target_type
:
if
t
==
"category"
:
target
.
append
(
self
.
y
[
index
])
...
...
@@ -101,14 +108,14 @@ class Caltech101(VisionDataset):
return
img
,
target
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
# can be more robust and check hash of files
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
"101_ObjectCategories"
))
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
index
)
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
return
...
...
@@ -124,7 +131,7 @@ class Caltech101(VisionDataset):
filename
=
"101_Annotations.tar"
,
md5
=
"6f83eeb1f24d99cab4eb377263132c91"
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
return
"Target type: {target_type}"
.
format
(
**
self
.
__dict__
)
...
...
@@ -143,7 +150,13 @@ class Caltech256(VisionDataset):
downloaded again.
"""
def
__init__
(
self
,
root
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Caltech256
,
self
).
__init__
(
os
.
path
.
join
(
root
,
'caltech256'
),
transform
=
transform
,
target_transform
=
target_transform
)
...
...
@@ -157,14 +170,14 @@ class Caltech256(VisionDataset):
' You can use download=True to download it'
)
self
.
categories
=
sorted
(
os
.
listdir
(
os
.
path
.
join
(
self
.
root
,
"256_ObjectCategories"
)))
self
.
index
=
[]
self
.
index
:
List
[
int
]
=
[]
self
.
y
=
[]
for
(
i
,
c
)
in
enumerate
(
self
.
categories
):
n
=
len
(
os
.
listdir
(
os
.
path
.
join
(
self
.
root
,
"256_ObjectCategories"
,
c
)))
self
.
index
.
extend
(
range
(
1
,
n
+
1
))
self
.
y
.
extend
(
n
*
[
i
])
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -187,14 +200,14 @@ class Caltech256(VisionDataset):
return
img
,
target
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
# can be more robust and check hash of files
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
"256_ObjectCategories"
))
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
index
)
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment