Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d367a01a
Unverified
Commit
d367a01a
authored
Oct 28, 2021
by
Jirka Borovec
Committed by
GitHub
Oct 28, 2021
Browse files
Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)
Co-authored-by:
Nicolas Hug
<
nicolashug@fb.com
>
parent
50dfe207
Changes
136
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
91 additions
and
100 deletions
+91
-100
test/test_transforms_tensor.py
test/test_transforms_tensor.py
+3
-3
test/test_transforms_video.py
test/test_transforms_video.py
+1
-1
torchvision/__init__.py
torchvision/__init__.py
+2
-2
torchvision/datasets/caltech.py
torchvision/datasets/caltech.py
+7
-11
torchvision/datasets/celeba.py
torchvision/datasets/celeba.py
+3
-3
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+5
-6
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+11
-11
torchvision/datasets/fakedata.py
torchvision/datasets/fakedata.py
+2
-4
torchvision/datasets/flickr.py
torchvision/datasets/flickr.py
+3
-3
torchvision/datasets/folder.py
torchvision/datasets/folder.py
+3
-3
torchvision/datasets/hmdb51.py
torchvision/datasets/hmdb51.py
+3
-3
torchvision/datasets/imagenet.py
torchvision/datasets/imagenet.py
+4
-4
torchvision/datasets/inaturalist.py
torchvision/datasets/inaturalist.py
+2
-4
torchvision/datasets/kinetics.py
torchvision/datasets/kinetics.py
+2
-2
torchvision/datasets/lfw.py
torchvision/datasets/lfw.py
+8
-10
torchvision/datasets/lsun.py
torchvision/datasets/lsun.py
+4
-4
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+13
-12
torchvision/datasets/omniglot.py
torchvision/datasets/omniglot.py
+3
-3
torchvision/datasets/phototour.py
torchvision/datasets/phototour.py
+10
-9
torchvision/datasets/places365.py
torchvision/datasets/places365.py
+2
-2
No files found.
test/test_transforms_tensor.py
View file @
d367a01a
...
...
@@ -371,7 +371,7 @@ def test_x_crop_save(method, tmpdir):
]
)
scripted_fn
=
torch
.
jit
.
script
(
fn
)
scripted_fn
.
save
(
os
.
path
.
join
(
tmpdir
,
"t_op_list_{
}.pt"
.
format
(
method
)
))
scripted_fn
.
save
(
os
.
path
.
join
(
tmpdir
,
f
"t_op_list_
{
method
}
.pt"
))
class
TestResize
:
...
...
@@ -816,7 +816,7 @@ def test_compose(device):
transformed_tensor
=
transforms
(
tensor
)
torch
.
manual_seed
(
12
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
assert_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
"{
}"
.
format
(
transforms
)
)
assert_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
f
"
{
transforms
}
"
)
t
=
T
.
Compose
(
[
...
...
@@ -854,7 +854,7 @@ def test_random_apply(device):
transformed_tensor
=
transforms
(
tensor
)
torch
.
manual_seed
(
12
)
transformed_tensor_script
=
scripted_fn
(
tensor
)
assert_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
"{
}"
.
format
(
transforms
)
)
assert_equal
(
transformed_tensor
,
transformed_tensor_script
,
msg
=
f
"
{
transforms
}
"
)
if
device
==
"cpu"
:
# Can't check this twice, otherwise
...
...
test/test_transforms_video.py
View file @
d367a01a
...
...
@@ -163,7 +163,7 @@ class TestVideoTransforms:
@
pytest
.
mark
.
parametrize
(
"p"
,
(
0
,
1
))
def
test_random_horizontal_flip_video
(
self
,
p
):
clip
=
torch
.
rand
((
3
,
4
,
112
,
112
),
dtype
=
torch
.
float
)
hclip
=
clip
.
flip
(
(
-
1
)
)
hclip
=
clip
.
flip
(
-
1
)
out
=
transforms
.
RandomHorizontalFlipVideo
(
p
=
p
)(
clip
)
if
p
==
0
:
...
...
torchvision/__init__.py
View file @
d367a01a
...
...
@@ -43,7 +43,7 @@ def set_image_backend(backend):
"""
global
_image_backend
if
backend
not
in
[
"PIL"
,
"accimage"
]:
raise
ValueError
(
"Invalid backend '{}'. Options are 'PIL' and 'accimage'"
.
format
(
backend
)
)
raise
ValueError
(
f
"Invalid backend '
{
backend
}
'. Options are 'PIL' and 'accimage'"
)
_image_backend
=
backend
...
...
@@ -74,7 +74,7 @@ def set_video_backend(backend):
if
backend
not
in
[
"pyav"
,
"video_reader"
]:
raise
ValueError
(
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'"
%
backend
)
if
backend
==
"video_reader"
and
not
io
.
_HAS_VIDEO_OPT
:
message
=
"video_reader video backend is not available.
"
"
Please compile torchvision from source and try again"
message
=
"video_reader video backend is not available. Please compile torchvision from source and try again"
warnings
.
warn
(
message
)
else
:
_video_backend
=
backend
...
...
torchvision/datasets/caltech.py
View file @
d367a01a
...
...
@@ -40,9 +40,7 @@ class Caltech101(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Caltech101
,
self
).
__init__
(
os
.
path
.
join
(
root
,
"caltech101"
),
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
os
.
path
.
join
(
root
,
"caltech101"
),
transform
=
transform
,
target_transform
=
target_transform
)
os
.
makedirs
(
self
.
root
,
exist_ok
=
True
)
if
isinstance
(
target_type
,
str
):
target_type
=
[
target_type
]
...
...
@@ -52,7 +50,7 @@ class Caltech101(VisionDataset):
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
categories
=
sorted
(
os
.
listdir
(
os
.
path
.
join
(
self
.
root
,
"101_ObjectCategories"
)))
self
.
categories
.
remove
(
"BACKGROUND_Google"
)
# this is not a real class
...
...
@@ -90,7 +88,7 @@ class Caltech101(VisionDataset):
self
.
root
,
"101_ObjectCategories"
,
self
.
categories
[
self
.
y
[
index
]],
"image_{
:04d}.jpg"
.
format
(
self
.
index
[
index
]
)
,
f
"image_
{
self
.
index
[
index
]
:
04
d
}
.jpg"
,
)
)
...
...
@@ -104,7 +102,7 @@ class Caltech101(VisionDataset):
self
.
root
,
"Annotations"
,
self
.
annotation_categories
[
self
.
y
[
index
]],
"annotation_{
:04d}.mat"
.
format
(
self
.
index
[
index
]
)
,
f
"annotation_
{
self
.
index
[
index
]
:
04
d
}
.mat"
,
)
)
target
.
append
(
data
[
"obj_contour"
])
...
...
@@ -167,16 +165,14 @@ class Caltech256(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Caltech256
,
self
).
__init__
(
os
.
path
.
join
(
root
,
"caltech256"
),
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
os
.
path
.
join
(
root
,
"caltech256"
),
transform
=
transform
,
target_transform
=
target_transform
)
os
.
makedirs
(
self
.
root
,
exist_ok
=
True
)
if
download
:
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
categories
=
sorted
(
os
.
listdir
(
os
.
path
.
join
(
self
.
root
,
"256_ObjectCategories"
)))
self
.
index
:
List
[
int
]
=
[]
...
...
@@ -205,7 +201,7 @@ class Caltech256(VisionDataset):
self
.
root
,
"256_ObjectCategories"
,
self
.
categories
[
self
.
y
[
index
]],
"{
:03d}_{:04d}.jpg"
.
format
(
self
.
y
[
index
]
+
1
,
self
.
index
[
index
]
)
,
f
"
{
self
.
y
[
index
]
+
1
:
03
d
}
_
{
self
.
index
[
index
]
:
04
d
}
.jpg"
,
)
)
...
...
torchvision/datasets/celeba.py
View file @
d367a01a
...
...
@@ -66,7 +66,7 @@ class CelebA(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
CelebA
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
split
=
split
if
isinstance
(
target_type
,
list
):
self
.
target_type
=
target_type
...
...
@@ -80,7 +80,7 @@ class CelebA(VisionDataset):
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
split_map
=
{
"train"
:
0
,
...
...
@@ -166,7 +166,7 @@ class CelebA(VisionDataset):
target
.
append
(
self
.
landmarks_align
[
index
,
:])
else
:
# TODO: refactor with utils.verify_str_arg
raise
ValueError
(
'Target type "{}" is not recognized.'
.
format
(
t
)
)
raise
ValueError
(
f
'Target type "
{
t
}
" is not recognized.'
)
if
self
.
transform
is
not
None
:
X
=
self
.
transform
(
X
)
...
...
torchvision/datasets/cifar.py
View file @
d367a01a
...
...
@@ -58,7 +58,7 @@ class CIFAR10(VisionDataset):
download
:
bool
=
False
,
)
->
None
:
super
(
CIFAR10
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
train
=
train
# training set or test set
...
...
@@ -66,7 +66,7 @@ class CIFAR10(VisionDataset):
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
if
self
.
train
:
downloaded_list
=
self
.
train_list
...
...
@@ -95,9 +95,7 @@ class CIFAR10(VisionDataset):
def
_load_meta
(
self
)
->
None
:
path
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
self
.
meta
[
"filename"
])
if
not
check_integrity
(
path
,
self
.
meta
[
"md5"
]):
raise
RuntimeError
(
"Dataset metadata file not found or corrupted."
+
" You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset metadata file not found or corrupted. You can use download=True to download it"
)
with
open
(
path
,
"rb"
)
as
infile
:
data
=
pickle
.
load
(
infile
,
encoding
=
"latin1"
)
self
.
classes
=
data
[
self
.
meta
[
"key"
]]
...
...
@@ -144,7 +142,8 @@ class CIFAR10(VisionDataset):
download_and_extract_archive
(
self
.
url
,
self
.
root
,
filename
=
self
.
filename
,
md5
=
self
.
tgz_md5
)
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
split
=
"Train"
if
self
.
train
is
True
else
"Test"
return
f
"Split:
{
split
}
"
class
CIFAR100
(
CIFAR10
):
...
...
torchvision/datasets/cityscapes.py
View file @
d367a01a
...
...
@@ -111,7 +111,7 @@ class Cityscapes(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
transforms
:
Optional
[
Callable
]
=
None
,
)
->
None
:
super
(
Cityscapes
,
self
).
__init__
(
root
,
transforms
,
transform
,
target_transform
)
super
().
__init__
(
root
,
transforms
,
transform
,
target_transform
)
self
.
mode
=
"gtFine"
if
mode
==
"fine"
else
"gtCoarse"
self
.
images_dir
=
os
.
path
.
join
(
self
.
root
,
"leftImg8bit"
,
split
)
self
.
targets_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
mode
,
split
)
...
...
@@ -125,7 +125,7 @@ class Cityscapes(VisionDataset):
valid_modes
=
(
"train"
,
"test"
,
"val"
)
else
:
valid_modes
=
(
"train"
,
"train_extra"
,
"val"
)
msg
=
"Unknown value '{}' for argument split if mode is '{}'.
"
"
Valid values are {{{}}}."
msg
=
"Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
msg
=
msg
.
format
(
split
,
mode
,
iterable_to_str
(
valid_modes
))
verify_str_arg
(
split
,
"split"
,
valid_modes
,
msg
)
...
...
@@ -139,14 +139,14 @@ class Cityscapes(VisionDataset):
if
not
os
.
path
.
isdir
(
self
.
images_dir
)
or
not
os
.
path
.
isdir
(
self
.
targets_dir
):
if
split
==
"train_extra"
:
image_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"leftImg8bit
{}"
.
format
(
"
_trainextra.zip"
)
)
image_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"leftImg8bit_trainextra.zip"
)
else
:
image_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"leftImg8bit
{}"
.
format
(
"
_trainvaltest.zip"
)
)
image_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"leftImg8bit_trainvaltest.zip"
)
if
self
.
mode
==
"gtFine"
:
target_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"{
}{}"
.
format
(
self
.
mode
,
"
_trainvaltest.zip"
)
)
target_dir_zip
=
os
.
path
.
join
(
self
.
root
,
f
"
{
self
.
mode
}
_trainvaltest.zip"
)
elif
self
.
mode
==
"gtCoarse"
:
target_dir_zip
=
os
.
path
.
join
(
self
.
root
,
"{
}{}"
.
format
(
self
.
mode
,
"
.zip"
)
)
target_dir_zip
=
os
.
path
.
join
(
self
.
root
,
f
"
{
self
.
mode
}
.zip"
)
if
os
.
path
.
isfile
(
image_dir_zip
)
and
os
.
path
.
isfile
(
target_dir_zip
):
extract_archive
(
from_path
=
image_dir_zip
,
to_path
=
self
.
root
)
...
...
@@ -206,16 +206,16 @@ class Cityscapes(VisionDataset):
return
"
\n
"
.
join
(
lines
).
format
(
**
self
.
__dict__
)
def
_load_json
(
self
,
path
:
str
)
->
Dict
[
str
,
Any
]:
with
open
(
path
,
"r"
)
as
file
:
with
open
(
path
)
as
file
:
data
=
json
.
load
(
file
)
return
data
def
_get_target_suffix
(
self
,
mode
:
str
,
target_type
:
str
)
->
str
:
if
target_type
==
"instance"
:
return
"{}_instanceIds.png"
.
format
(
mode
)
return
f
"
{
mode
}
_instanceIds.png"
elif
target_type
==
"semantic"
:
return
"{}_labelIds.png"
.
format
(
mode
)
return
f
"
{
mode
}
_labelIds.png"
elif
target_type
==
"color"
:
return
"{}_color.png"
.
format
(
mode
)
return
f
"
{
mode
}
_color.png"
else
:
return
"{}_polygons.json"
.
format
(
mode
)
return
f
"
{
mode
}
_polygons.json"
torchvision/datasets/fakedata.py
View file @
d367a01a
...
...
@@ -31,9 +31,7 @@ class FakeData(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
random_offset
:
int
=
0
,
)
->
None
:
super
(
FakeData
,
self
).
__init__
(
None
,
transform
=
transform
,
target_transform
=
target_transform
# type: ignore[arg-type]
)
super
().
__init__
(
None
,
transform
=
transform
,
target_transform
=
target_transform
)
# type: ignore[arg-type]
self
.
size
=
size
self
.
num_classes
=
num_classes
self
.
image_size
=
image_size
...
...
@@ -49,7 +47,7 @@ class FakeData(VisionDataset):
"""
# create random image that is consistent with the index id
if
index
>=
len
(
self
):
raise
IndexError
(
"{
} index out of range"
.
format
(
self
.
__class__
.
__name__
)
)
raise
IndexError
(
f
"
{
self
.
__class__
.
__name__
}
index out of range"
)
rng_state
=
torch
.
get_rng_state
()
torch
.
manual_seed
(
index
+
self
.
random_offset
)
img
=
torch
.
randn
(
*
self
.
image_size
)
...
...
torchvision/datasets/flickr.py
View file @
d367a01a
...
...
@@ -13,7 +13,7 @@ class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page."""
def
__init__
(
self
,
root
:
str
)
->
None
:
super
(
Flickr8kParser
,
self
).
__init__
()
super
().
__init__
()
self
.
root
=
root
...
...
@@ -71,7 +71,7 @@ class Flickr8k(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
)
->
None
:
super
(
Flickr8k
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
ann_file
=
os
.
path
.
expanduser
(
ann_file
)
# Read annotations and store in a dict
...
...
@@ -127,7 +127,7 @@ class Flickr30k(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
)
->
None
:
super
(
Flickr30k
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
ann_file
=
os
.
path
.
expanduser
(
ann_file
)
# Read annotations and store in a dict
...
...
torchvision/datasets/folder.py
View file @
d367a01a
...
...
@@ -140,7 +140,7 @@ class DatasetFolder(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
)
->
None
:
super
(
DatasetFolder
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
classes
,
class_to_idx
=
self
.
find_classes
(
self
.
root
)
samples
=
self
.
make_dataset
(
self
.
root
,
class_to_idx
,
extensions
,
is_valid_file
)
...
...
@@ -254,7 +254,7 @@ def accimage_loader(path: str) -> Any:
try
:
return
accimage
.
Image
(
path
)
except
I
OError
:
except
O
S
Error
:
# Potentially a decoding problem, fall back to PIL.Image
return
pil_loader
(
path
)
...
...
@@ -306,7 +306,7 @@ class ImageFolder(DatasetFolder):
loader
:
Callable
[[
str
],
Any
]
=
default_loader
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
):
super
(
ImageFolder
,
self
).
__init__
(
super
().
__init__
(
root
,
loader
,
IMG_EXTENSIONS
if
is_valid_file
is
None
else
None
,
...
...
torchvision/datasets/hmdb51.py
View file @
d367a01a
...
...
@@ -72,9 +72,9 @@ class HMDB51(VisionDataset):
_video_min_dimension
:
int
=
0
,
_audio_samples
:
int
=
0
,
)
->
None
:
super
(
HMDB51
,
self
).
__init__
(
root
)
super
().
__init__
(
root
)
if
fold
not
in
(
1
,
2
,
3
):
raise
ValueError
(
"fold should be between 1 and 3, got {
}"
.
format
(
fold
)
)
raise
ValueError
(
f
"fold should be between 1 and 3, got
{
fold
}
"
)
extensions
=
(
"avi"
,)
self
.
classes
,
class_to_idx
=
find_classes
(
self
.
root
)
...
...
@@ -113,7 +113,7 @@ class HMDB51(VisionDataset):
def
_select_fold
(
self
,
video_list
:
List
[
str
],
annotations_dir
:
str
,
fold
:
int
,
train
:
bool
)
->
List
[
int
]:
target_tag
=
self
.
TRAIN_TAG
if
train
else
self
.
TEST_TAG
split_pattern_name
=
"*test_split{}.txt"
.
format
(
fold
)
split_pattern_name
=
f
"*test_split
{
fold
}
.txt"
split_pattern_path
=
os
.
path
.
join
(
annotations_dir
,
split_pattern_name
)
annotation_paths
=
glob
.
glob
(
split_pattern_path
)
selected_files
=
set
()
...
...
torchvision/datasets/imagenet.py
View file @
d367a01a
...
...
@@ -49,7 +49,7 @@ class ImageNet(ImageFolder):
)
raise
RuntimeError
(
msg
)
elif
download
is
False
:
msg
=
"The use of the download flag is deprecated, since the dataset
"
"
is no longer publicly accessible."
msg
=
"The use of the download flag is deprecated, since the dataset is no longer publicly accessible."
warnings
.
warn
(
msg
,
RuntimeWarning
)
root
=
self
.
root
=
os
.
path
.
expanduser
(
root
)
...
...
@@ -58,7 +58,7 @@ class ImageNet(ImageFolder):
self
.
parse_archives
()
wnid_to_classes
=
load_meta_file
(
self
.
root
)[
0
]
super
(
ImageNet
,
self
).
__init__
(
self
.
split_folder
,
**
kwargs
)
super
().
__init__
(
self
.
split_folder
,
**
kwargs
)
self
.
root
=
root
self
.
wnids
=
self
.
classes
...
...
@@ -132,7 +132,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def
parse_val_groundtruth_txt
(
devkit_root
:
str
)
->
List
[
int
]:
file
=
os
.
path
.
join
(
devkit_root
,
"data"
,
"ILSVRC2012_validation_ground_truth.txt"
)
with
open
(
file
,
"r"
)
as
txtfh
:
with
open
(
file
)
as
txtfh
:
val_idcs
=
txtfh
.
readlines
()
return
[
int
(
val_idx
)
for
val_idx
in
val_idcs
]
...
...
@@ -215,7 +215,7 @@ def parse_val_archive(
val_root
=
os
.
path
.
join
(
root
,
folder
)
extract_archive
(
os
.
path
.
join
(
root
,
file
),
val_root
)
images
=
sorted
(
[
os
.
path
.
join
(
val_root
,
image
)
for
image
in
os
.
listdir
(
val_root
)
]
)
images
=
sorted
(
os
.
path
.
join
(
val_root
,
image
)
for
image
in
os
.
listdir
(
val_root
))
for
wnid
in
set
(
wnids
):
os
.
mkdir
(
os
.
path
.
join
(
val_root
,
wnid
))
...
...
torchvision/datasets/inaturalist.py
View file @
d367a01a
...
...
@@ -74,16 +74,14 @@ class INaturalist(VisionDataset):
)
->
None
:
self
.
version
=
verify_str_arg
(
version
,
"version"
,
DATASET_URLS
.
keys
())
super
(
INaturalist
,
self
).
__init__
(
os
.
path
.
join
(
root
,
version
),
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
os
.
path
.
join
(
root
,
version
),
transform
=
transform
,
target_transform
=
target_transform
)
os
.
makedirs
(
root
,
exist_ok
=
True
)
if
download
:
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
all_categories
:
List
[
str
]
=
[]
...
...
torchvision/datasets/kinetics.py
View file @
d367a01a
...
...
@@ -175,7 +175,7 @@ class Kinetics(VisionDataset):
split_url_filepath
=
path
.
join
(
file_list_path
,
path
.
basename
(
split_url
))
if
not
check_integrity
(
split_url_filepath
):
download_url
(
split_url
,
file_list_path
)
list_video_urls
=
open
(
split_url_filepath
,
"r"
)
list_video_urls
=
open
(
split_url_filepath
)
if
self
.
num_download_workers
==
1
:
for
line
in
list_video_urls
.
readlines
():
...
...
@@ -309,7 +309,7 @@ class Kinetics400(Kinetics):
"Kinetics400. Please use Kinetics instead."
)
super
(
Kinetics400
,
self
).
__init__
(
super
().
__init__
(
root
=
root
,
frames_per_clip
=
frames_per_clip
,
_legacy
=
True
,
...
...
torchvision/datasets/lfw.py
View file @
d367a01a
...
...
@@ -39,9 +39,7 @@ class _LFW(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
super
(
_LFW
,
self
).
__init__
(
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
image_set
=
verify_str_arg
(
image_set
.
lower
(),
"image_set"
,
self
.
file_dict
.
keys
())
images_dir
,
self
.
filename
,
self
.
md5
=
self
.
file_dict
[
self
.
image_set
]
...
...
@@ -55,7 +53,7 @@ class _LFW(VisionDataset):
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
images_dir
=
os
.
path
.
join
(
self
.
root
,
images_dir
)
...
...
@@ -122,14 +120,14 @@ class LFWPeople(_LFW):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
super
(
LFWPeople
,
self
).
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
super
().
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
self
.
class_to_idx
=
self
.
_get_classes
()
self
.
data
,
self
.
targets
=
self
.
_get_people
()
def
_get_people
(
self
):
data
,
targets
=
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
n_folds
,
s
=
(
int
(
lines
[
0
]),
1
)
if
self
.
split
==
"10fold"
else
(
1
,
0
)
...
...
@@ -146,7 +144,7 @@ class LFWPeople(_LFW):
return
data
,
targets
def
_get_classes
(
self
):
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
))
as
f
:
lines
=
f
.
readlines
()
names
=
[
line
.
strip
().
split
()[
0
]
for
line
in
lines
]
class_to_idx
=
{
name
:
i
for
i
,
name
in
enumerate
(
names
)}
...
...
@@ -172,7 +170,7 @@ class LFWPeople(_LFW):
return
img
,
target
def
extra_repr
(
self
)
->
str
:
return
super
().
extra_repr
()
+
"
\n
Classes (identities): {
}"
.
format
(
len
(
self
.
class_to_idx
)
)
return
super
().
extra_repr
()
+
f
"
\n
Classes (identities):
{
len
(
self
.
class_to_idx
)
}
"
class
LFWPairs
(
_LFW
):
...
...
@@ -204,13 +202,13 @@ class LFWPairs(_LFW):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
super
(
LFWPairs
,
self
).
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
super
().
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
self
.
pair_names
,
self
.
data
,
self
.
targets
=
self
.
_get_pairs
(
self
.
images_dir
)
def
_get_pairs
(
self
,
images_dir
):
pair_names
,
data
,
targets
=
[],
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
if
self
.
split
==
"10fold"
:
n_folds
,
n_pairs
=
lines
[
0
].
split
(
"
\t
"
)
...
...
torchvision/datasets/lsun.py
View file @
d367a01a
...
...
@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset):
)
->
None
:
import
lmdb
super
(
LSUNClass
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
env
=
lmdb
.
open
(
root
,
max_readers
=
1
,
readonly
=
True
,
lock
=
False
,
readahead
=
False
,
meminit
=
False
)
with
self
.
env
.
begin
(
write
=
False
)
as
txn
:
...
...
@@ -77,7 +77,7 @@ class LSUN(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
)
->
None
:
super
(
LSUN
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
classes
=
self
.
_verify_classes
(
classes
)
# for each class, create an LSUNClassDataset
...
...
@@ -117,11 +117,11 @@ class LSUN(VisionDataset):
classes
=
[
c
+
"_"
+
classes
for
c
in
categories
]
except
ValueError
:
if
not
isinstance
(
classes
,
Iterable
):
msg
=
"Expected type str or Iterable for argument classes,
"
"
but got type {}."
msg
=
"Expected type str or Iterable for argument classes, but got type {}."
raise
ValueError
(
msg
.
format
(
type
(
classes
)))
classes
=
list
(
classes
)
msg_fmtstr_type
=
"Expected type str for elements in argument classes,
"
"
but got type {}."
msg_fmtstr_type
=
"Expected type str for elements in argument classes, but got type {}."
for
c
in
classes
:
verify_str_arg
(
c
,
custom_msg
=
msg_fmtstr_type
.
format
(
type
(
c
)))
c_short
=
c
.
split
(
"_"
)
...
...
torchvision/datasets/mnist.py
View file @
d367a01a
...
...
@@ -88,7 +88,7 @@ class MNIST(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
MNIST
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
train
=
train
# training set or test set
if
self
.
_check_legacy_exist
():
...
...
@@ -99,7 +99,7 @@ class MNIST(VisionDataset):
self
.
download
()
if
not
self
.
_check_exists
():
raise
RuntimeError
(
"Dataset not found.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found. You can use download=True to download it"
)
self
.
data
,
self
.
targets
=
self
.
_load_data
()
...
...
@@ -181,21 +181,22 @@ class MNIST(VisionDataset):
# download files
for
filename
,
md5
in
self
.
resources
:
for
mirror
in
self
.
mirrors
:
url
=
"{
}{}"
.
format
(
mirror
,
filename
)
url
=
f
"
{
mirror
}{
filename
}
"
try
:
print
(
"Downloading {
}"
.
format
(
url
)
)
print
(
f
"Downloading
{
url
}
"
)
download_and_extract_archive
(
url
,
download_root
=
self
.
raw_folder
,
filename
=
filename
,
md5
=
md5
)
except
URLError
as
error
:
print
(
"Failed to download (trying next):
\n
{
}"
.
format
(
error
)
)
print
(
f
"Failed to download (trying next):
\n
{
error
}
"
)
continue
finally
:
print
()
break
else
:
raise
RuntimeError
(
"Error downloading {
}"
.
format
(
filename
)
)
raise
RuntimeError
(
f
"Error downloading
{
filename
}
"
)
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
split
=
"Train"
if
self
.
train
is
True
else
"Test"
return
f
"Split:
{
split
}
"
class
FashionMNIST
(
MNIST
):
...
...
@@ -293,16 +294,16 @@ class EMNIST(MNIST):
self
.
split
=
verify_str_arg
(
split
,
"split"
,
self
.
splits
)
self
.
training_file
=
self
.
_training_file
(
split
)
self
.
test_file
=
self
.
_test_file
(
split
)
super
(
EMNIST
,
self
).
__init__
(
root
,
**
kwargs
)
super
().
__init__
(
root
,
**
kwargs
)
self
.
classes
=
self
.
classes_split_dict
[
self
.
split
]
@
staticmethod
def
_training_file
(
split
)
->
str
:
return
"training_{}.pt"
.
format
(
split
)
return
f
"training_
{
split
}
.pt"
@
staticmethod
def
_test_file
(
split
)
->
str
:
return
"test_{}.pt"
.
format
(
split
)
return
f
"test_
{
split
}
.pt"
@
property
def
_file_prefix
(
self
)
->
str
:
...
...
@@ -424,7 +425,7 @@ class QMNIST(MNIST):
self
.
data_file
=
what
+
".pt"
self
.
training_file
=
self
.
data_file
self
.
test_file
=
self
.
data_file
super
(
QMNIST
,
self
).
__init__
(
root
,
train
,
**
kwargs
)
super
().
__init__
(
root
,
train
,
**
kwargs
)
@
property
def
images_file
(
self
)
->
str
:
...
...
@@ -482,7 +483,7 @@ class QMNIST(MNIST):
return
img
,
target
def
extra_repr
(
self
)
->
str
:
return
"Split: {
}"
.
format
(
self
.
what
)
return
f
"Split:
{
self
.
what
}
"
def
get_int
(
b
:
bytes
)
->
int
:
...
...
torchvision/datasets/omniglot.py
View file @
d367a01a
...
...
@@ -39,19 +39,19 @@ class Omniglot(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
Omniglot
,
self
).
__init__
(
join
(
root
,
self
.
folder
),
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
join
(
root
,
self
.
folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
background
=
background
if
download
:
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
target_folder
=
join
(
self
.
root
,
self
.
_get_target_folder
())
self
.
_alphabets
=
list_dir
(
self
.
target_folder
)
self
.
_characters
:
List
[
str
]
=
sum
(
[
[
join
(
a
,
c
)
for
c
in
list_dir
(
join
(
self
.
target_folder
,
a
))]
for
a
in
self
.
_alphabets
]
,
[]
(
[
join
(
a
,
c
)
for
c
in
list_dir
(
join
(
self
.
target_folder
,
a
))]
for
a
in
self
.
_alphabets
)
,
[]
)
self
.
_character_images
=
[
[(
image
,
idx
)
for
image
in
list_files
(
join
(
self
.
target_folder
,
character
),
".png"
)]
...
...
torchvision/datasets/phototour.py
View file @
d367a01a
...
...
@@ -89,11 +89,11 @@ class PhotoTour(VisionDataset):
def
__init__
(
self
,
root
:
str
,
name
:
str
,
train
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
)
->
None
:
super
(
PhotoTour
,
self
).
__init__
(
root
,
transform
=
transform
)
super
().
__init__
(
root
,
transform
=
transform
)
self
.
name
=
name
self
.
data_dir
=
os
.
path
.
join
(
self
.
root
,
name
)
self
.
data_down
=
os
.
path
.
join
(
self
.
root
,
"{}.zip"
.
format
(
name
)
)
self
.
data_file
=
os
.
path
.
join
(
self
.
root
,
"{}.pt"
.
format
(
name
)
)
self
.
data_down
=
os
.
path
.
join
(
self
.
root
,
f
"
{
name
}
.zip"
)
self
.
data_file
=
os
.
path
.
join
(
self
.
root
,
f
"
{
name
}
.pt"
)
self
.
train
=
train
self
.
mean
=
self
.
means
[
name
]
...
...
@@ -139,7 +139,7 @@ class PhotoTour(VisionDataset):
def
download
(
self
)
->
None
:
if
self
.
_check_datafile_exists
():
print
(
"# Found cached data {
}"
.
format
(
self
.
data_file
)
)
print
(
f
"# Found cached data
{
self
.
data_file
}
"
)
return
if
not
self
.
_check_downloaded
():
...
...
@@ -151,7 +151,7 @@ class PhotoTour(VisionDataset):
download_url
(
url
,
self
.
root
,
filename
,
md5
)
print
(
"# Extracting data {
}
\n
"
.
format
(
self
.
data_down
)
)
print
(
f
"# Extracting data
{
self
.
data_down
}
\n
"
)
import
zipfile
...
...
@@ -162,7 +162,7 @@ class PhotoTour(VisionDataset):
def
cache
(
self
)
->
None
:
# process and save as torch files
print
(
"# Caching data {
}"
.
format
(
self
.
data_file
)
)
print
(
f
"# Caching data
{
self
.
data_file
}
"
)
dataset
=
(
read_image_file
(
self
.
data_dir
,
self
.
image_ext
,
self
.
lens
[
self
.
name
]),
...
...
@@ -174,7 +174,8 @@ class PhotoTour(VisionDataset):
torch
.
save
(
dataset
,
f
)
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
split
=
"Train"
if
self
.
train
is
True
else
"Test"
return
f
"Split:
{
split
}
"
def
read_image_file
(
data_dir
:
str
,
image_ext
:
str
,
n
:
int
)
->
torch
.
Tensor
:
...
...
@@ -209,7 +210,7 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point.
"""
with
open
(
os
.
path
.
join
(
data_dir
,
info_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
data_dir
,
info_file
))
as
f
:
labels
=
[
int
(
line
.
split
()[
0
])
for
line
in
f
]
return
torch
.
LongTensor
(
labels
)
...
...
@@ -220,7 +221,7 @@ def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
Matches are represented with a 1, non matches with a 0.
"""
matches
=
[]
with
open
(
os
.
path
.
join
(
data_dir
,
matches_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
data_dir
,
matches_file
))
as
f
:
for
line
in
f
:
line_split
=
line
.
split
()
matches
.
append
([
int
(
line_split
[
0
]),
int
(
line_split
[
3
]),
int
(
line_split
[
1
]
==
line_split
[
4
])])
...
...
torchvision/datasets/places365.py
View file @
d367a01a
...
...
@@ -117,7 +117,7 @@ class Places365(VisionDataset):
if
not
self
.
_check_integrity
(
file
,
md5
,
download
):
self
.
download_devkit
()
with
open
(
file
,
"r"
)
as
fh
:
with
open
(
file
)
as
fh
:
class_to_idx
=
dict
(
process
(
line
)
for
line
in
fh
)
return
sorted
(
class_to_idx
.
keys
()),
class_to_idx
...
...
@@ -132,7 +132,7 @@ class Places365(VisionDataset):
if
not
self
.
_check_integrity
(
file
,
md5
,
download
):
self
.
download_devkit
()
with
open
(
file
,
"r"
)
as
fh
:
with
open
(
file
)
as
fh
:
images
=
[
process
(
line
)
for
line
in
fh
]
_
,
targets
=
zip
(
*
images
)
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment