Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
40333c5a
Unverified
Commit
40333c5a
authored
Jul 31, 2020
by
Philip Meier
Committed by
GitHub
Jul 31, 2020
Browse files
celba (#2522)
parent
31245cb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
11 deletions
+19
-11
torchvision/datasets/celeba.py
torchvision/datasets/celeba.py
+19
-11
No files found.
torchvision/datasets/celeba.py
View file @
40333c5a
...
...
@@ -2,6 +2,7 @@ from functools import partial
import
torch
import
os
import
PIL
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
,
Tuple
from
.vision
import
VisionDataset
from
.utils
import
download_file_from_google_drive
,
check_integrity
,
verify_str_arg
...
...
@@ -48,8 +49,15 @@ class CelebA(VisionDataset):
(
"0B7EVK8r0v71pY0NSMzRuSXJEVkk"
,
"d32c9cbf5e040fd4025c592c306e6668"
,
"list_eval_partition.txt"
),
]
def
__init__
(
self
,
root
,
split
=
"train"
,
target_type
=
"attr"
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"attr"
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
import
pandas
super
(
CelebA
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
...
...
@@ -75,8 +83,8 @@ class CelebA(VisionDataset):
"test"
:
2
,
"all"
:
None
,
}
split
=
split_map
[
verify_str_arg
(
split
.
lower
(),
"split"
,
(
"train"
,
"valid"
,
"test"
,
"all"
))]
split
_
=
split_map
[
verify_str_arg
(
split
.
lower
(),
"split"
,
(
"train"
,
"valid"
,
"test"
,
"all"
))]
fn
=
partial
(
os
.
path
.
join
,
self
.
root
,
self
.
base_folder
)
splits
=
pandas
.
read_csv
(
fn
(
"list_eval_partition.txt"
),
delim_whitespace
=
True
,
header
=
None
,
index_col
=
0
)
...
...
@@ -85,7 +93,7 @@ class CelebA(VisionDataset):
landmarks_align
=
pandas
.
read_csv
(
fn
(
"list_landmarks_align_celeba.txt"
),
delim_whitespace
=
True
,
header
=
1
)
attr
=
pandas
.
read_csv
(
fn
(
"list_attr_celeba.txt"
),
delim_whitespace
=
True
,
header
=
1
)
mask
=
slice
(
None
)
if
split
is
None
else
(
splits
[
1
]
==
split
)
mask
=
slice
(
None
)
if
split
_
is
None
else
(
splits
[
1
]
==
split
_
)
self
.
filename
=
splits
[
mask
].
index
.
values
self
.
identity
=
torch
.
as_tensor
(
identity
[
mask
].
values
)
...
...
@@ -95,7 +103,7 @@ class CelebA(VisionDataset):
self
.
attr
=
(
self
.
attr
+
1
)
//
2
# map from {-1, 1} to {0, 1}
self
.
attr_names
=
list
(
attr
.
columns
)
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
for
(
_
,
md5
,
filename
)
in
self
.
file_list
:
fpath
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
filename
)
_
,
ext
=
os
.
path
.
splitext
(
filename
)
...
...
@@ -107,7 +115,7 @@ class CelebA(VisionDataset):
# Should check a hash of the images
return
os
.
path
.
isdir
(
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
"img_align_celeba"
))
def
download
(
self
):
def
download
(
self
)
->
None
:
import
zipfile
if
self
.
_check_integrity
():
...
...
@@ -120,10 +128,10 @@ class CelebA(VisionDataset):
with
zipfile
.
ZipFile
(
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
"img_align_celeba.zip"
),
"r"
)
as
f
:
f
.
extractall
(
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
))
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
X
=
PIL
.
Image
.
open
(
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
"img_align_celeba"
,
self
.
filename
[
index
]))
target
=
[]
target
:
Any
=
[]
for
t
in
self
.
target_type
:
if
t
==
"attr"
:
target
.
append
(
self
.
attr
[
index
,
:])
...
...
@@ -150,9 +158,9 @@ class CelebA(VisionDataset):
return
X
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
attr
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
lines
=
[
"Target type: {target_type}"
,
"Split: {split}"
]
return
'
\n
'
.
join
(
lines
).
format
(
**
self
.
__dict__
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment