Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
52b80c48
Unverified
Commit
52b80c48
authored
Oct 27, 2022
by
F-G Fernandez
Committed by
GitHub
Oct 27, 2022
Browse files
style: Added typing to datasets/lfw (#6844)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
e0068d8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
11 deletions
+11
-11
torchvision/datasets/lfw.py
torchvision/datasets/lfw.py
+11
-11
No files found.
torchvision/datasets/lfw.py
View file @
52b80c48
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
...
...
@@ -38,7 +38,7 @@ class _LFW(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
image_set
=
verify_str_arg
(
image_set
.
lower
(),
"image_set"
,
self
.
file_dict
.
keys
())
...
...
@@ -62,7 +62,7 @@ class _LFW(VisionDataset):
img
=
Image
.
open
(
f
)
return
img
.
convert
(
"RGB"
)
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
st1
=
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
),
self
.
md5
)
st2
=
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
),
self
.
checksums
[
self
.
labels_file
])
if
not
st1
or
not
st2
:
...
...
@@ -71,7 +71,7 @@ class _LFW(VisionDataset):
return
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
names
),
self
.
checksums
[
self
.
names
])
return
True
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
"Files already downloaded and verified"
)
return
...
...
@@ -81,13 +81,13 @@ class _LFW(VisionDataset):
if
self
.
view
==
"people"
:
download_url
(
f
"
{
self
.
download_url_prefix
}{
self
.
names
}
"
,
self
.
root
)
def
_get_path
(
self
,
identity
,
no
)
:
def
_get_path
(
self
,
identity
:
str
,
no
:
Union
[
int
,
str
])
->
str
:
return
os
.
path
.
join
(
self
.
images_dir
,
identity
,
f
"
{
identity
}
_
{
int
(
no
):
04
d
}
.jpg"
)
def
extra_repr
(
self
)
->
str
:
return
f
"Alignment:
{
self
.
image_set
}
\n
Split:
{
self
.
split
}
"
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
...
...
@@ -119,13 +119,13 @@ class LFWPeople(_LFW):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
self
.
class_to_idx
=
self
.
_get_classes
()
self
.
data
,
self
.
targets
=
self
.
_get_people
()
def
_get_people
(
self
):
def
_get_people
(
self
)
->
Tuple
[
List
[
str
],
List
[
int
]]
:
data
,
targets
=
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
...
...
@@ -143,7 +143,7 @@ class LFWPeople(_LFW):
return
data
,
targets
def
_get_classes
(
self
):
def
_get_classes
(
self
)
->
Dict
[
str
,
int
]
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
))
as
f
:
lines
=
f
.
readlines
()
names
=
[
line
.
strip
().
split
()[
0
]
for
line
in
lines
]
...
...
@@ -201,12 +201,12 @@ class LFWPairs(_LFW):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
self
.
pair_names
,
self
.
data
,
self
.
targets
=
self
.
_get_pairs
(
self
.
images_dir
)
def
_get_pairs
(
self
,
images_dir
)
:
def
_get_pairs
(
self
,
images_dir
:
str
)
->
Tuple
[
List
[
Tuple
[
str
,
str
]],
List
[
Tuple
[
str
,
str
]],
List
[
int
]]
:
pair_names
,
data
,
targets
=
[],
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment