OpenDAS / vision / Commits / d367a01a

Unverified commit d367a01a, authored Oct 28, 2021 by Jirka Borovec, committed by GitHub on Oct 28, 2021.

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)

Co-authored-by: Nicolas Hug <nicolashug@fb.com>

Parent: 50dfe207
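The rewrites in this commit are the mechanical modernizations pyupgrade applies to Python 3-only code: `.format()` calls become f-strings, `super(Class, self)` becomes zero-argument `super()`, redundant `"r"` modes are dropped from `open()`, and `IOError` becomes `OSError`. A minimal before/after sketch of these patterns (illustrative only, not taken verbatim from the diff):

import os


class LegacyDataset:
    # Python 2-era idioms of the kind this commit removes.
    def __init__(self, root: str, name: str) -> None:
        super(LegacyDataset, self).__init__()                      # explicit super(Class, self)
        self.data_file = os.path.join(root, "{}.pt".format(name))  # str.format()

    def extra_repr(self) -> str:
        return "Split: {}".format("Train")                         # str.format()


class ModernDataset:
    # The same class after a pyupgrade-style rewrite.
    def __init__(self, root: str, name: str) -> None:
        super().__init__()                                          # zero-argument super()
        self.data_file = os.path.join(root, f"{name}.pt")           # f-string

    def extra_repr(self) -> str:
        split = "Train"
        return f"Split: {split}"


if __name__ == "__main__":
    print(LegacyDataset("/tmp", "mnist").data_file)   # /tmp/mnist.pt
    print(ModernDataset("/tmp", "mnist").data_file)   # /tmp/mnist.pt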
Changes: 136 changed files in the commit. Showing 20 changed files with 91 additions and 100 deletions (+91 -100).
test/test_transforms_tensor.py        +3  -3
test/test_transforms_video.py         +1  -1
torchvision/__init__.py               +2  -2
torchvision/datasets/caltech.py       +7  -11
torchvision/datasets/celeba.py        +3  -3
torchvision/datasets/cifar.py         +5  -6
torchvision/datasets/cityscapes.py    +11 -11
torchvision/datasets/fakedata.py      +2  -4
torchvision/datasets/flickr.py        +3  -3
torchvision/datasets/folder.py        +3  -3
torchvision/datasets/hmdb51.py        +3  -3
torchvision/datasets/imagenet.py      +4  -4
torchvision/datasets/inaturalist.py   +2  -4
torchvision/datasets/kinetics.py      +2  -2
torchvision/datasets/lfw.py           +8  -10
torchvision/datasets/lsun.py          +4  -4
torchvision/datasets/mnist.py         +13 -12
torchvision/datasets/omniglot.py      +3  -3
torchvision/datasets/phototour.py     +10 -9
torchvision/datasets/places365.py     +2  -2
test/test_transforms_tensor.py

@@ -371,7 +371,7 @@ def test_x_crop_save(method, tmpdir):
             ]
         )
     scripted_fn = torch.jit.script(fn)
-    scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))
+    scripted_fn.save(os.path.join(tmpdir, f"t_op_list_{method}.pt"))


 class TestResize:

@@ -816,7 +816,7 @@ def test_compose(device):
     transformed_tensor = transforms(tensor)
     torch.manual_seed(12)
     transformed_tensor_script = scripted_fn(tensor)
-    assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
+    assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")

     t = T.Compose(
         [

@@ -854,7 +854,7 @@ def test_random_apply(device):
     transformed_tensor = transforms(tensor)
     torch.manual_seed(12)
     transformed_tensor_script = scripted_fn(tensor)
-    assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
+    assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
     if device == "cpu":
         # Can't check this twice, otherwise
test/test_transforms_video.py

@@ -163,7 +163,7 @@ class TestVideoTransforms:
     @pytest.mark.parametrize("p", (0, 1))
     def test_random_horizontal_flip_video(self, p):
         clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
-        hclip = clip.flip((-1))
+        hclip = clip.flip(-1)
         out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
         if p == 0:
torchvision/__init__.py

@@ -43,7 +43,7 @@ def set_image_backend(backend):
     """
     global _image_backend
     if backend not in ["PIL", "accimage"]:
-        raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend))
+        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
     _image_backend = backend

@@ -74,7 +74,7 @@ def set_video_backend(backend):
     if backend not in ["pyav", "video_reader"]:
         raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
     if backend == "video_reader" and not io._HAS_VIDEO_OPT:
-        message = "video_reader video backend is not available. " "Please compile torchvision from source and try again"
+        message = "video_reader video backend is not available. Please compile torchvision from source and try again"
         warnings.warn(message)
     else:
         _video_backend = backend
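Several changes in this commit (here and in cityscapes.py, imagenet.py, and lsun.py below) collapse adjacent string literals into a single literal. Adjacent literals are concatenated by the parser, which makes the split form easy to get wrong at the join point; a small standalone illustration of the pitfall (hypothetical strings, not copied from the diff):

# Adjacent string literals are joined at compile time.
part_a = "video backend is not available. " "Please compile from source"
single = "video backend is not available. Please compile from source"
assert part_a == single

# The common bug: forgetting the separating space where the literals meet.
broken = "video backend is not available." "Please compile from source"
assert broken != single  # "...available.Please compile..."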
torchvision/datasets/caltech.py

@@ -40,9 +40,7 @@ class Caltech101(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(Caltech101, self).__init__(
-            os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform
-        )
+        super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
         os.makedirs(self.root, exist_ok=True)
         if isinstance(target_type, str):
             target_type = [target_type]

@@ -52,7 +50,7 @@ class Caltech101(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
         self.categories.remove("BACKGROUND_Google")  # this is not a real class

@@ -90,7 +88,7 @@ class Caltech101(VisionDataset):
                 self.root,
                 "101_ObjectCategories",
                 self.categories[self.y[index]],
-                "image_{:04d}.jpg".format(self.index[index]),
+                f"image_{self.index[index]:04d}.jpg",
             )
         )

@@ -104,7 +102,7 @@ class Caltech101(VisionDataset):
                 self.root,
                 "Annotations",
                 self.annotation_categories[self.y[index]],
-                "annotation_{:04d}.mat".format(self.index[index]),
+                f"annotation_{self.index[index]:04d}.mat",
             )
         )
         target.append(data["obj_contour"])

@@ -167,16 +165,14 @@ class Caltech256(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(Caltech256, self).__init__(
-            os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform
-        )
+        super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
         os.makedirs(self.root, exist_ok=True)

         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
         self.index: List[int] = []

@@ -205,7 +201,7 @@ class Caltech256(VisionDataset):
                 self.root,
                 "256_ObjectCategories",
                 self.categories[self.y[index]],
-                "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]),
+                f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
             )
         )
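The caltech.py change shows that format specifications such as `:04d` and `:03d` carry over unchanged from `str.format()` to f-strings; only the value moves inside the braces, ahead of the spec. A quick standalone check (hypothetical values):

index = 7
category = 3

old_style = "image_{:04d}.jpg".format(index)
new_style = f"image_{index:04d}.jpg"
assert old_style == new_style == "image_0007.jpg"

# Arbitrary expressions are allowed before the format spec as well.
assert f"{category + 1:03d}_{index:04d}.jpg" == "004_0007.jpg"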
torchvision/datasets/celeba.py

@@ -66,7 +66,7 @@ class CelebA(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.split = split
         if isinstance(target_type, list):
             self.target_type = target_type

@@ -80,7 +80,7 @@ class CelebA(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         split_map = {
             "train": 0,

@@ -166,7 +166,7 @@ class CelebA(VisionDataset):
                 target.append(self.landmarks_align[index, :])
             else:
                 # TODO: refactor with utils.verify_str_arg
-                raise ValueError('Target type "{}" is not recognized.'.format(t))
+                raise ValueError(f'Target type "{t}" is not recognized.')

         if self.transform is not None:
             X = self.transform(X)
torchvision/datasets/cifar.py

@@ -58,7 +58,7 @@ class CIFAR10(VisionDataset):
         download: bool = False,
     ) -> None:
-        super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.train = train  # training set or test set

@@ -66,7 +66,7 @@ class CIFAR10(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         if self.train:
             downloaded_list = self.train_list

@@ -95,9 +95,7 @@ class CIFAR10(VisionDataset):
     def _load_meta(self) -> None:
         path = os.path.join(self.root, self.base_folder, self.meta["filename"])
         if not check_integrity(path, self.meta["md5"]):
-            raise RuntimeError(
-                "Dataset metadata file not found or corrupted." + " You can use download=True to download it"
-            )
+            raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
         with open(path, "rb") as infile:
             data = pickle.load(infile, encoding="latin1")
             self.classes = data[self.meta["key"]]

@@ -144,7 +142,8 @@ class CIFAR10(VisionDataset):
         download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

     def extra_repr(self) -> str:
-        return "Split: {}".format("Train" if self.train is True else "Test")
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"


 class CIFAR100(CIFAR10):
torchvision/datasets/cityscapes.py

@@ -111,7 +111,7 @@ class Cityscapes(VisionDataset):
         target_transform: Optional[Callable] = None,
         transforms: Optional[Callable] = None,
     ) -> None:
-        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
+        super().__init__(root, transforms, transform, target_transform)
         self.mode = "gtFine" if mode == "fine" else "gtCoarse"
         self.images_dir = os.path.join(self.root, "leftImg8bit", split)
         self.targets_dir = os.path.join(self.root, self.mode, split)

@@ -125,7 +125,7 @@ class Cityscapes(VisionDataset):
             valid_modes = ("train", "test", "val")
         else:
             valid_modes = ("train", "train_extra", "val")
-        msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}."
+        msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
         msg = msg.format(split, mode, iterable_to_str(valid_modes))
         verify_str_arg(split, "split", valid_modes, msg)

@@ -139,14 +139,14 @@ class Cityscapes(VisionDataset):
         if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):

             if split == "train_extra":
-                image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip"))
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
             else:
-                image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip"))
+                image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")

             if self.mode == "gtFine":
-                target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip"))
+                target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
             elif self.mode == "gtCoarse":
-                target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip"))
+                target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")

             if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
                 extract_archive(from_path=image_dir_zip, to_path=self.root)

@@ -206,16 +206,16 @@ class Cityscapes(VisionDataset):
         return "\n".join(lines).format(**self.__dict__)

     def _load_json(self, path: str) -> Dict[str, Any]:
-        with open(path, "r") as file:
+        with open(path) as file:
             data = json.load(file)
         return data

     def _get_target_suffix(self, mode: str, target_type: str) -> str:
         if target_type == "instance":
-            return "{}_instanceIds.png".format(mode)
+            return f"{mode}_instanceIds.png"
         elif target_type == "semantic":
-            return "{}_labelIds.png".format(mode)
+            return f"{mode}_labelIds.png"
         elif target_type == "color":
-            return "{}_color.png".format(mode)
+            return f"{mode}_color.png"
         else:
-            return "{}_polygons.json".format(mode)
+            return f"{mode}_polygons.json"
torchvision/datasets/fakedata.py

@@ -31,9 +31,7 @@ class FakeData(VisionDataset):
         target_transform: Optional[Callable] = None,
         random_offset: int = 0,
     ) -> None:
-        super(FakeData, self).__init__(
-            None, transform=transform, target_transform=target_transform
-        )  # type: ignore[arg-type]
+        super().__init__(None, transform=transform, target_transform=target_transform)  # type: ignore[arg-type]
         self.size = size
         self.num_classes = num_classes
         self.image_size = image_size

@@ -49,7 +47,7 @@ class FakeData(VisionDataset):
         """
         # create random image that is consistent with the index id
         if index >= len(self):
-            raise IndexError("{} index out of range".format(self.__class__.__name__))
+            raise IndexError(f"{self.__class__.__name__} index out of range")
         rng_state = torch.get_rng_state()
         torch.manual_seed(index + self.random_offset)
         img = torch.randn(*self.image_size)
torchvision/datasets/flickr.py

@@ -13,7 +13,7 @@ class Flickr8kParser(HTMLParser):
     """Parser for extracting captions from the Flickr8k dataset web page."""

     def __init__(self, root: str) -> None:
-        super(Flickr8kParser, self).__init__()
+        super().__init__()

         self.root = root

@@ -71,7 +71,7 @@ class Flickr8k(VisionDataset):
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
-        super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.ann_file = os.path.expanduser(ann_file)

         # Read annotations and store in a dict

@@ -127,7 +127,7 @@ class Flickr30k(VisionDataset):
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
-        super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.ann_file = os.path.expanduser(ann_file)

         # Read annotations and store in a dict
torchvision/datasets/folder.py

@@ -140,7 +140,7 @@ class DatasetFolder(VisionDataset):
         target_transform: Optional[Callable] = None,
         is_valid_file: Optional[Callable[[str], bool]] = None,
     ) -> None:
-        super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         classes, class_to_idx = self.find_classes(self.root)
         samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

@@ -254,7 +254,7 @@ def accimage_loader(path: str) -> Any:
     try:
         return accimage.Image(path)
-    except IOError:
+    except OSError:
         # Potentially a decoding problem, fall back to PIL.Image
         return pil_loader(path)

@@ -306,7 +306,7 @@ class ImageFolder(DatasetFolder):
         loader: Callable[[str], Any] = default_loader,
         is_valid_file: Optional[Callable[[str], bool]] = None,
     ):
-        super(ImageFolder, self).__init__(
+        super().__init__(
             root,
             loader,
             IMG_EXTENSIONS if is_valid_file is None else None,
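The `except IOError` to `except OSError` change in folder.py is purely cosmetic on Python 3: since Python 3.3, `IOError` is an alias of `OSError`, so both clauses catch exactly the same exceptions. A quick check:

# IOError has been an alias of OSError since Python 3.3.
assert IOError is OSError

try:
    open("/path/that/does/not/exist")
except OSError as exc:          # also catches anything formerly spelled IOError
    print(type(exc).__name__)   # FileNotFoundError, a subclass of OSError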
torchvision/datasets/hmdb51.py

@@ -72,9 +72,9 @@ class HMDB51(VisionDataset):
         _video_min_dimension: int = 0,
         _audio_samples: int = 0,
     ) -> None:
-        super(HMDB51, self).__init__(root)
+        super().__init__(root)
         if fold not in (1, 2, 3):
-            raise ValueError("fold should be between 1 and 3, got {}".format(fold))
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")

         extensions = ("avi",)
         self.classes, class_to_idx = find_classes(self.root)

@@ -113,7 +113,7 @@ class HMDB51(VisionDataset):
     def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
         target_tag = self.TRAIN_TAG if train else self.TEST_TAG
-        split_pattern_name = "*test_split{}.txt".format(fold)
+        split_pattern_name = f"*test_split{fold}.txt"
         split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
         annotation_paths = glob.glob(split_pattern_path)
         selected_files = set()
torchvision/datasets/imagenet.py

@@ -49,7 +49,7 @@ class ImageNet(ImageFolder):
             )
             raise RuntimeError(msg)
         elif download is False:
-            msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible."
+            msg = "The use of the download flag is deprecated, since the dataset is no longer publicly accessible."
             warnings.warn(msg, RuntimeWarning)

         root = self.root = os.path.expanduser(root)

@@ -58,7 +58,7 @@ class ImageNet(ImageFolder):
             self.parse_archives()
         wnid_to_classes = load_meta_file(self.root)[0]

-        super(ImageNet, self).__init__(self.split_folder, **kwargs)
+        super().__init__(self.split_folder, **kwargs)
         self.root = root

         self.wnids = self.classes

@@ -132,7 +132,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
     def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
         file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
-        with open(file, "r") as txtfh:
+        with open(file) as txtfh:
             val_idcs = txtfh.readlines()
         return [int(val_idx) for val_idx in val_idcs]

@@ -215,7 +215,7 @@ def parse_val_archive(
     val_root = os.path.join(root, folder)
     extract_archive(os.path.join(root, file), val_root)

-    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
+    images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))

     for wnid in set(wnids):
         os.mkdir(os.path.join(val_root, wnid))
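The imagenet.py change replaces a list comprehension inside `sorted()` with a generator expression; the result is identical, it just avoids building an intermediate list before sorting. For example (hypothetical directory listing):

entries = ["n01443537", "n01440764", "n01484850"]

from_list = sorted([e.upper() for e in entries])   # builds a temporary list first
from_gen = sorted(e.upper() for e in entries)      # feeds sorted() directly
assert from_list == from_gen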
torchvision/datasets/inaturalist.py

@@ -74,16 +74,14 @@ class INaturalist(VisionDataset):
     ) -> None:
         self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

-        super(INaturalist, self).__init__(
-            os.path.join(root, version), transform=transform, target_transform=target_transform
-        )
+        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)

         os.makedirs(root, exist_ok=True)
         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         self.all_categories: List[str] = []
torchvision/datasets/kinetics.py

@@ -175,7 +175,7 @@ class Kinetics(VisionDataset):
         split_url_filepath = path.join(file_list_path, path.basename(split_url))
         if not check_integrity(split_url_filepath):
             download_url(split_url, file_list_path)
-        list_video_urls = open(split_url_filepath, "r")
+        list_video_urls = open(split_url_filepath)
         if self.num_download_workers == 1:
             for line in list_video_urls.readlines():

@@ -309,7 +309,7 @@ class Kinetics400(Kinetics):
             "Kinetics400. Please use Kinetics instead."
         )
-        super(Kinetics400, self).__init__(
+        super().__init__(
             root=root,
             frames_per_clip=frames_per_clip,
             _legacy=True,
View file @
d367a01a
...
@@ -39,9 +39,7 @@ class _LFW(VisionDataset):
...
@@ -39,9 +39,7 @@ class _LFW(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
download
:
bool
=
False
,
):
):
super
(
_LFW
,
self
).
__init__
(
super
().
__init__
(
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
image_set
=
verify_str_arg
(
image_set
.
lower
(),
"image_set"
,
self
.
file_dict
.
keys
())
self
.
image_set
=
verify_str_arg
(
image_set
.
lower
(),
"image_set"
,
self
.
file_dict
.
keys
())
images_dir
,
self
.
filename
,
self
.
md5
=
self
.
file_dict
[
self
.
image_set
]
images_dir
,
self
.
filename
,
self
.
md5
=
self
.
file_dict
[
self
.
image_set
]
...
@@ -55,7 +53,7 @@ class _LFW(VisionDataset):
...
@@ -55,7 +53,7 @@ class _LFW(VisionDataset):
self
.
download
()
self
.
download
()
if
not
self
.
_check_integrity
():
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
"Dataset not found or corrupted.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found or corrupted. You can use download=True to download it"
)
self
.
images_dir
=
os
.
path
.
join
(
self
.
root
,
images_dir
)
self
.
images_dir
=
os
.
path
.
join
(
self
.
root
,
images_dir
)
...
@@ -122,14 +120,14 @@ class LFWPeople(_LFW):
...
@@ -122,14 +120,14 @@ class LFWPeople(_LFW):
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
download
:
bool
=
False
,
):
):
super
(
LFWPeople
,
self
).
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
super
().
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
self
.
class_to_idx
=
self
.
_get_classes
()
self
.
class_to_idx
=
self
.
_get_classes
()
self
.
data
,
self
.
targets
=
self
.
_get_people
()
self
.
data
,
self
.
targets
=
self
.
_get_people
()
def
_get_people
(
self
):
def
_get_people
(
self
):
data
,
targets
=
[],
[]
data
,
targets
=
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
n_folds
,
s
=
(
int
(
lines
[
0
]),
1
)
if
self
.
split
==
"10fold"
else
(
1
,
0
)
n_folds
,
s
=
(
int
(
lines
[
0
]),
1
)
if
self
.
split
==
"10fold"
else
(
1
,
0
)
...
@@ -146,7 +144,7 @@ class LFWPeople(_LFW):
...
@@ -146,7 +144,7 @@ class LFWPeople(_LFW):
return
data
,
targets
return
data
,
targets
def
_get_classes
(
self
):
def
_get_classes
(
self
):
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
))
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
names
=
[
line
.
strip
().
split
()[
0
]
for
line
in
lines
]
names
=
[
line
.
strip
().
split
()[
0
]
for
line
in
lines
]
class_to_idx
=
{
name
:
i
for
i
,
name
in
enumerate
(
names
)}
class_to_idx
=
{
name
:
i
for
i
,
name
in
enumerate
(
names
)}
...
@@ -172,7 +170,7 @@ class LFWPeople(_LFW):
...
@@ -172,7 +170,7 @@ class LFWPeople(_LFW):
return
img
,
target
return
img
,
target
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
return
super
().
extra_repr
()
+
"
\n
Classes (identities): {
}"
.
format
(
len
(
self
.
class_to_idx
)
)
return
super
().
extra_repr
()
+
f
"
\n
Classes (identities):
{
len
(
self
.
class_to_idx
)
}
"
class
LFWPairs
(
_LFW
):
class
LFWPairs
(
_LFW
):
...
@@ -204,13 +202,13 @@ class LFWPairs(_LFW):
...
@@ -204,13 +202,13 @@ class LFWPairs(_LFW):
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
download
:
bool
=
False
,
):
):
super
(
LFWPairs
,
self
).
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
super
().
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
self
.
pair_names
,
self
.
data
,
self
.
targets
=
self
.
_get_pairs
(
self
.
images_dir
)
self
.
pair_names
,
self
.
data
,
self
.
targets
=
self
.
_get_pairs
(
self
.
images_dir
)
def
_get_pairs
(
self
,
images_dir
):
def
_get_pairs
(
self
,
images_dir
):
pair_names
,
data
,
targets
=
[],
[],
[]
pair_names
,
data
,
targets
=
[],
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
)
,
"r"
)
as
f
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
if
self
.
split
==
"10fold"
:
if
self
.
split
==
"10fold"
:
n_folds
,
n_pairs
=
lines
[
0
].
split
(
"
\t
"
)
n_folds
,
n_pairs
=
lines
[
0
].
split
(
"
\t
"
)
...
...
torchvision/datasets/lsun.py
View file @
d367a01a
...
@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset):
...
@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset):
)
->
None
:
)
->
None
:
import
lmdb
import
lmdb
super
(
LSUNClass
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
env
=
lmdb
.
open
(
root
,
max_readers
=
1
,
readonly
=
True
,
lock
=
False
,
readahead
=
False
,
meminit
=
False
)
self
.
env
=
lmdb
.
open
(
root
,
max_readers
=
1
,
readonly
=
True
,
lock
=
False
,
readahead
=
False
,
meminit
=
False
)
with
self
.
env
.
begin
(
write
=
False
)
as
txn
:
with
self
.
env
.
begin
(
write
=
False
)
as
txn
:
...
@@ -77,7 +77,7 @@ class LSUN(VisionDataset):
...
@@ -77,7 +77,7 @@ class LSUN(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
)
->
None
:
)
->
None
:
super
(
LSUN
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
classes
=
self
.
_verify_classes
(
classes
)
self
.
classes
=
self
.
_verify_classes
(
classes
)
# for each class, create an LSUNClassDataset
# for each class, create an LSUNClassDataset
...
@@ -117,11 +117,11 @@ class LSUN(VisionDataset):
...
@@ -117,11 +117,11 @@ class LSUN(VisionDataset):
classes
=
[
c
+
"_"
+
classes
for
c
in
categories
]
classes
=
[
c
+
"_"
+
classes
for
c
in
categories
]
except
ValueError
:
except
ValueError
:
if
not
isinstance
(
classes
,
Iterable
):
if
not
isinstance
(
classes
,
Iterable
):
msg
=
"Expected type str or Iterable for argument classes,
"
"
but got type {}."
msg
=
"Expected type str or Iterable for argument classes, but got type {}."
raise
ValueError
(
msg
.
format
(
type
(
classes
)))
raise
ValueError
(
msg
.
format
(
type
(
classes
)))
classes
=
list
(
classes
)
classes
=
list
(
classes
)
msg_fmtstr_type
=
"Expected type str for elements in argument classes,
"
"
but got type {}."
msg_fmtstr_type
=
"Expected type str for elements in argument classes, but got type {}."
for
c
in
classes
:
for
c
in
classes
:
verify_str_arg
(
c
,
custom_msg
=
msg_fmtstr_type
.
format
(
type
(
c
)))
verify_str_arg
(
c
,
custom_msg
=
msg_fmtstr_type
.
format
(
type
(
c
)))
c_short
=
c
.
split
(
"_"
)
c_short
=
c
.
split
(
"_"
)
...
...
torchvision/datasets/mnist.py
View file @
d367a01a
...
@@ -88,7 +88,7 @@ class MNIST(VisionDataset):
...
@@ -88,7 +88,7 @@ class MNIST(VisionDataset):
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
download
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
(
MNIST
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
super
().
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
train
=
train
# training set or test set
self
.
train
=
train
# training set or test set
if
self
.
_check_legacy_exist
():
if
self
.
_check_legacy_exist
():
...
@@ -99,7 +99,7 @@ class MNIST(VisionDataset):
...
@@ -99,7 +99,7 @@ class MNIST(VisionDataset):
self
.
download
()
self
.
download
()
if
not
self
.
_check_exists
():
if
not
self
.
_check_exists
():
raise
RuntimeError
(
"Dataset not found.
"
+
"
You can use download=True to download it"
)
raise
RuntimeError
(
"Dataset not found. You can use download=True to download it"
)
self
.
data
,
self
.
targets
=
self
.
_load_data
()
self
.
data
,
self
.
targets
=
self
.
_load_data
()
...
@@ -181,21 +181,22 @@ class MNIST(VisionDataset):
...
@@ -181,21 +181,22 @@ class MNIST(VisionDataset):
# download files
# download files
for
filename
,
md5
in
self
.
resources
:
for
filename
,
md5
in
self
.
resources
:
for
mirror
in
self
.
mirrors
:
for
mirror
in
self
.
mirrors
:
url
=
"{
}{}"
.
format
(
mirror
,
filename
)
url
=
f
"
{
mirror
}{
filename
}
"
try
:
try
:
print
(
"Downloading {
}"
.
format
(
url
)
)
print
(
f
"Downloading
{
url
}
"
)
download_and_extract_archive
(
url
,
download_root
=
self
.
raw_folder
,
filename
=
filename
,
md5
=
md5
)
download_and_extract_archive
(
url
,
download_root
=
self
.
raw_folder
,
filename
=
filename
,
md5
=
md5
)
except
URLError
as
error
:
except
URLError
as
error
:
print
(
"Failed to download (trying next):
\n
{
}"
.
format
(
error
)
)
print
(
f
"Failed to download (trying next):
\n
{
error
}
"
)
continue
continue
finally
:
finally
:
print
()
print
()
break
break
else
:
else
:
raise
RuntimeError
(
"Error downloading {
}"
.
format
(
filename
)
)
raise
RuntimeError
(
f
"Error downloading
{
filename
}
"
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
split
=
"Train"
if
self
.
train
is
True
else
"Test"
return
f
"Split:
{
split
}
"
class
FashionMNIST
(
MNIST
):
class
FashionMNIST
(
MNIST
):
...
@@ -293,16 +294,16 @@ class EMNIST(MNIST):
...
@@ -293,16 +294,16 @@ class EMNIST(MNIST):
self
.
split
=
verify_str_arg
(
split
,
"split"
,
self
.
splits
)
self
.
split
=
verify_str_arg
(
split
,
"split"
,
self
.
splits
)
self
.
training_file
=
self
.
_training_file
(
split
)
self
.
training_file
=
self
.
_training_file
(
split
)
self
.
test_file
=
self
.
_test_file
(
split
)
self
.
test_file
=
self
.
_test_file
(
split
)
super
(
EMNIST
,
self
).
__init__
(
root
,
**
kwargs
)
super
().
__init__
(
root
,
**
kwargs
)
self
.
classes
=
self
.
classes_split_dict
[
self
.
split
]
self
.
classes
=
self
.
classes_split_dict
[
self
.
split
]
@
staticmethod
@
staticmethod
def
_training_file
(
split
)
->
str
:
def
_training_file
(
split
)
->
str
:
return
"training_{}.pt"
.
format
(
split
)
return
f
"training_
{
split
}
.pt"
@
staticmethod
@
staticmethod
def
_test_file
(
split
)
->
str
:
def
_test_file
(
split
)
->
str
:
return
"test_{}.pt"
.
format
(
split
)
return
f
"test_
{
split
}
.pt"
@
property
@
property
def
_file_prefix
(
self
)
->
str
:
def
_file_prefix
(
self
)
->
str
:
...
@@ -424,7 +425,7 @@ class QMNIST(MNIST):
...
@@ -424,7 +425,7 @@ class QMNIST(MNIST):
self
.
data_file
=
what
+
".pt"
self
.
data_file
=
what
+
".pt"
self
.
training_file
=
self
.
data_file
self
.
training_file
=
self
.
data_file
self
.
test_file
=
self
.
data_file
self
.
test_file
=
self
.
data_file
super
(
QMNIST
,
self
).
__init__
(
root
,
train
,
**
kwargs
)
super
().
__init__
(
root
,
train
,
**
kwargs
)
@
property
@
property
def
images_file
(
self
)
->
str
:
def
images_file
(
self
)
->
str
:
...
@@ -482,7 +483,7 @@ class QMNIST(MNIST):
...
@@ -482,7 +483,7 @@ class QMNIST(MNIST):
return
img
,
target
return
img
,
target
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
return
"Split: {
}"
.
format
(
self
.
what
)
return
f
"Split:
{
self
.
what
}
"
def
get_int
(
b
:
bytes
)
->
int
:
def
get_int
(
b
:
bytes
)
->
int
:
...
...
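The mnist.py download loop relies on Python's `for ... else`: the `else` block runs only when the loop finished without hitting `break`, which is why `raise RuntimeError(f"Error downloading {filename}")` fires only after every mirror has failed. A stripped-down sketch of the same control flow (hypothetical mirror list, stand-in download function):

mirrors = ["http://mirror-a.example/", "http://mirror-b.example/"]
filename = "train-images-idx3-ubyte.gz"


def try_download(url: str) -> bool:
    # Stand-in for download_and_extract_archive; pretend only mirror-b works.
    return url.startswith("http://mirror-b")


for mirror in mirrors:
    url = f"{mirror}{filename}"
    if try_download(url):
        print(f"Downloaded {url}")
        break
else:
    # Runs only when the loop completed without break, i.e. all mirrors failed.
    raise RuntimeError(f"Error downloading {filename}")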
torchvision/datasets/omniglot.py

@@ -39,19 +39,19 @@ class Omniglot(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
+        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
         self.background = background

         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " + "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         self.target_folder = join(self.root, self._get_target_folder())
         self._alphabets = list_dir(self.target_folder)
         self._characters: List[str] = sum(
-            [[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], []
+            ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
         )
         self._character_images = [
             [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
torchvision/datasets/phototour.py

@@ -89,11 +89,11 @@ class PhotoTour(VisionDataset):
     def __init__(
         self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
     ) -> None:
-        super(PhotoTour, self).__init__(root, transform=transform)
+        super().__init__(root, transform=transform)
         self.name = name
         self.data_dir = os.path.join(self.root, name)
-        self.data_down = os.path.join(self.root, "{}.zip".format(name))
-        self.data_file = os.path.join(self.root, "{}.pt".format(name))
+        self.data_down = os.path.join(self.root, f"{name}.zip")
+        self.data_file = os.path.join(self.root, f"{name}.pt")

         self.train = train
         self.mean = self.means[name]

@@ -139,7 +139,7 @@ class PhotoTour(VisionDataset):
     def download(self) -> None:
         if self._check_datafile_exists():
-            print("# Found cached data {}".format(self.data_file))
+            print(f"# Found cached data {self.data_file}")
             return

         if not self._check_downloaded():

@@ -151,7 +151,7 @@ class PhotoTour(VisionDataset):
             download_url(url, self.root, filename, md5)

-            print("# Extracting data {}\n".format(self.data_down))
+            print(f"# Extracting data {self.data_down}\n")

             import zipfile

@@ -162,7 +162,7 @@ class PhotoTour(VisionDataset):
     def cache(self) -> None:
         # process and save as torch files
-        print("# Caching data {}".format(self.data_file))
+        print(f"# Caching data {self.data_file}")

         dataset = (
             read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),

@@ -174,7 +174,8 @@ class PhotoTour(VisionDataset):
             torch.save(dataset, f)

     def extra_repr(self) -> str:
-        return "Split: {}".format("Train" if self.train is True else "Test")
+        split = "Train" if self.train is True else "Test"
+        return f"Split: {split}"


 def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:

@@ -209,7 +210,7 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
     """Return a Tensor containing the list of labels
     Read the file and keep only the ID of the 3D point.
     """
-    with open(os.path.join(data_dir, info_file), "r") as f:
+    with open(os.path.join(data_dir, info_file)) as f:
         labels = [int(line.split()[0]) for line in f]
     return torch.LongTensor(labels)

@@ -220,7 +221,7 @@ def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
     Matches are represented with a 1, non matches with a 0.
     """
     matches = []
-    with open(os.path.join(data_dir, matches_file), "r") as f:
+    with open(os.path.join(data_dir, matches_file)) as f:
         for line in f:
             line_split = line.split()
             matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
torchvision/datasets/places365.py

@@ -117,7 +117,7 @@ class Places365(VisionDataset):
         if not self._check_integrity(file, md5, download):
             self.download_devkit()

-        with open(file, "r") as fh:
+        with open(file) as fh:
             class_to_idx = dict(process(line) for line in fh)

         return sorted(class_to_idx.keys()), class_to_idx

@@ -132,7 +132,7 @@ class Places365(VisionDataset):
         if not self._check_integrity(file, md5, download):
             self.download_devkit()

-        with open(file, "r") as fh:
+        with open(file) as fh:
             images = [process(line) for line in fh]

         _, targets = zip(*images)