OpenDAS / vision — Commit d367a01a (unverified)

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)

Authored Oct 28, 2021 by Jirka Borovec; committed via GitHub on Oct 28, 2021.
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
Parent: 50dfe207
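A note on the tooling: the commit message credits pyupgrade for these rewrites, but does not record the exact invocation. As a minimal sketch, assuming the `--py36-plus` flag (which enables the f-string, zero-argument `super()`, `(object)`-removal, and `open()`-mode rewrites seen below), the command would look something like:

    pip install pyupgrade
    pyupgrade --py36-plus torchvision/datasets/*.py torchvision/io/*.py torchvision/models/*.py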
Changes: 136 files in the full commit; this page shows 20 changed files with 55 additions and 60 deletions (+55 −60):

torchvision/datasets/samplers/clip_sampler.py  +3 −3
torchvision/datasets/sbd.py                    +4 −4
torchvision/datasets/sbu.py                    +2 −2
torchvision/datasets/semeion.py                +2 −2
torchvision/datasets/stl10.py                  +4 −4
torchvision/datasets/svhn.py                   +2 −2
torchvision/datasets/ucf101.py                 +4 −4
torchvision/datasets/usps.py                   +1 −1
torchvision/datasets/utils.py                  +4 −4
torchvision/datasets/video_utils.py            +4 −4
torchvision/datasets/vision.py                 +6 −6
torchvision/datasets/voc.py                    +1 −1
torchvision/datasets/widerface.py              +5 −7
torchvision/extension.py                       +3 −4
torchvision/io/_video_opt.py                   +2 −2
torchvision/io/image.py                        +1 −1
torchvision/io/video.py                        +1 −3
torchvision/models/_utils.py                   +1 −1
torchvision/models/alexnet.py                  +1 −1
torchvision/models/densenet.py                 +4 −4
torchvision/datasets/samplers/clip_sampler.py

@@ -54,7 +54,7 @@ class DistributedSampler(Sampler):
         rank = dist.get_rank()
         assert (
             len(dataset) % group_size == 0
-        ), "dataset length must be a multiplier of group size" " dataset length: %d, group size: %d" % (
+        ), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % (
             len(dataset),
             group_size,
         )
@@ -117,7 +117,7 @@ class UniformClipSampler(Sampler):
     def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
         if not isinstance(video_clips, VideoClips):
-            raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
         self.video_clips = video_clips
         self.num_clips_per_video = num_clips_per_video
@@ -151,7 +151,7 @@ class RandomClipSampler(Sampler):
     def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
         if not isinstance(video_clips, VideoClips):
-            raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
+            raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
         self.video_clips = video_clips
         self.max_clips_per_video = max_clips_per_video
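The recurring rewrite in this file is the core pyupgrade transformation: adjacent string literals fed to str.format become a single f-string. A minimal sketch of the before/after equivalence, using a stand-in value rather than a real VideoClips object:

    video_clips = [1, 2, 3]  # stand-in value for illustration

    # Before: implicit concatenation of two literals, then str.format
    old = "Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))
    # After: one f-string, evaluated inline
    new = f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}"
    assert old == new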
torchvision/datasets/sbd.py

@@ -63,9 +63,9 @@ class SBDataset(VisionDataset):
             self._loadmat = loadmat
         except ImportError:
-            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy")
+            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")

-        super(SBDataset, self).__init__(root, transforms)
+        super().__init__(root, transforms)

         self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
         self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
         self.num_classes = 20
@@ -83,11 +83,11 @@ class SBDataset(VisionDataset):
             download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)

         if not os.path.isdir(sbd_root):
-            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
-        with open(os.path.join(split_f), "r") as fh:
+        with open(os.path.join(split_f)) as fh:
             file_names = [x.strip() for x in fh.readlines()]

         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
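The `super(SBDataset, self)` form is Python 2 legacy; inside a class body, Python 3's zero-argument `super()` resolves the same class and instance automatically. A minimal, self-contained sketch (class names here are illustrative, not from torchvision):

    class VisionBase:
        def __init__(self, root):
            self.root = root

    class SBDLike(VisionBase):
        def __init__(self, root):
            # Zero-argument form; equivalent to super(SBDLike, self).__init__(root)
            super().__init__(root)

    assert SBDLike("/data").root == "/data"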
torchvision/datasets/sbu.py

@@ -33,13 +33,13 @@ class SBU(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = True,
     ) -> None:
-        super(SBU, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)

         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         # Read the caption for each photo
         self.photos = []
torchvision/datasets/semeion.py

@@ -35,13 +35,13 @@ class SEMEION(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = True,
     ) -> None:
-        super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)

         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         fp = os.path.join(self.root, self.filename)
         data = np.loadtxt(fp)
torchvision/datasets/stl10.py

@@ -53,14 +53,14 @@ class STL10(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(STL10, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.split = verify_str_arg(split, "split", self.splits)
         self.folds = self._verify_folds(folds)

         if download:
             self.download()
         elif not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " "You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         # now load the picked numpy arrays
         self.labels: Optional[np.ndarray]
@@ -92,7 +92,7 @@ class STL10(VisionDataset):
         elif isinstance(folds, int):
             if folds in range(10):
                 return folds
-            msg = "Value for argument folds should be in the range [0, 10), " "but got {}."
+            msg = "Value for argument folds should be in the range [0, 10), but got {}."
             raise ValueError(msg.format(folds))
         else:
             msg = "Expected type None or int for argument folds, but got type {}."
@@ -167,7 +167,7 @@ class STL10(VisionDataset):
         if folds is None:
             return

         path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
-        with open(path_to_folds, "r") as f:
+        with open(path_to_folds) as f:
             str_idx = f.read().splitlines()[folds]
             list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
             self.data = self.data[list_idx, :, :, :]
torchvision/datasets/svhn.py

@@ -60,7 +60,7 @@ class SVHN(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
         self.url = self.split_list[split][0]
         self.filename = self.split_list[split][1]
@@ -70,7 +70,7 @@ class SVHN(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

         # import here rather than at top of file because this is
         # an optional dependency for torchvision
torchvision/datasets/ucf101.py

@@ -65,9 +65,9 @@ class UCF101(VisionDataset):
         _video_min_dimension: int = 0,
         _audio_samples: int = 0,
     ) -> None:
-        super(UCF101, self).__init__(root)
+        super().__init__(root)
         if not 1 <= fold <= 3:
-            raise ValueError("fold should be between 1 and 3, got {}".format(fold))
+            raise ValueError(f"fold should be between 1 and 3, got {fold}")

         extensions = ("avi",)
         self.fold = fold
@@ -102,10 +102,10 @@ class UCF101(VisionDataset):
     def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]:
         name = "train" if train else "test"
-        name = "{}list{:02d}.txt".format(name, fold)
+        name = f"{name}list{fold:02d}.txt"
         f = os.path.join(annotation_path, name)
         selected_files = set()
-        with open(f, "r") as fid:
+        with open(f) as fid:
             data = fid.readlines()
             data = [x.strip().split(" ")[0] for x in data]
             data = [os.path.join(self.root, x) for x in data]
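Note that format specifiers survive the conversion unchanged: the `{:02d}` of str.format becomes `{fold:02d}` inside the f-string. A quick sketch checking the equivalence:

    name, fold = "train", 1
    assert "{}list{:02d}.txt".format(name, fold) == f"{name}list{fold:02d}.txt" == "trainlist01.txt"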
torchvision/datasets/usps.py

@@ -49,7 +49,7 @@ class USPS(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
+        super().__init__(root, transform=transform, target_transform=target_transform)
         split = "train" if train else "test"
         url, filename, checksum = self.split_list[split]
         full_path = os.path.join(self.root, filename)
torchvision/datasets/utils.py

@@ -138,10 +138,10 @@ def download_url(
     try:
         print("Downloading " + url + " to " + fpath)
         _urlretrieve(url, fpath)
-    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
+    except (urllib.error.URLError, OSError) as e:  # type: ignore[attr-defined]
         if url[:5] == "https":
             url = url.replace("https:", "http:")
-            print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath)
+            print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
             _urlretrieve(url, fpath)
         else:
             raise e
@@ -428,7 +428,7 @@ def download_and_extract_archive(
     download_url(url, download_root, filename, md5)

     archive = os.path.join(download_root, filename)
-    print("Extracting {} to {}".format(archive, extract_root))
+    print(f"Extracting {archive} to {extract_root}")
     extract_archive(archive, extract_root, remove_finished)
@@ -460,7 +460,7 @@ def verify_str_arg(
         if custom_msg is not None:
             msg = custom_msg
         else:
-            msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}."
+            msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
             msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
         raise ValueError(msg)
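The IOError → OSError swap is purely cosmetic on Python 3: since Python 3.3, IOError has been an alias of OSError, so the except clause catches exactly the same exceptions before and after. A quick demonstration:

    # IOError is kept only for backward compatibility; it is the same class.
    assert IOError is OSError

    try:
        open("/nonexistent/path")
    except OSError as e:
        # This also catches what older code spelled as IOError
        print(type(e).__name__)  # FileNotFoundError, a subclass of OSError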
torchvision/datasets/video_utils.py

@@ -46,7 +46,7 @@ def unfold(tensor, size, step, dilation=1):
     return torch.as_strided(tensor, new_size, new_stride)


-class _VideoTimestampsDataset(object):
+class _VideoTimestampsDataset:
     """
     Dataset used to parallelize the reading of the timestamps
     of a list of videos, given their paths in the filesystem.
@@ -72,7 +72,7 @@ def _collate_fn(x):
     return x


-class VideoClips(object):
+class VideoClips:
     """
     Given a list of video files, computes all consecutive subvideos of size
     `clip_length_in_frames`, where the distance between each subvideo in the
@@ -293,7 +293,7 @@ class VideoClips(object):
             video_idx (int): index of the video in `video_paths`
         """
         if idx >= self.num_clips():
-            raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips()))
+            raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
         video_idx, clip_idx = self.get_clip_location(idx)
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
@@ -359,7 +359,7 @@ class VideoClips(object):
             resampling_idx = resampling_idx - resampling_idx[0]
             video = video[resampling_idx]
             info["video_fps"] = self.frame_rate
-        assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
+        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx

     def __getstate__(self):
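Dropping `(object)` is safe because every Python 3 class is a new-style class; inheriting from object explicitly changes nothing. A minimal check (illustrative class names):

    class WithBase(object):
        pass

    class WithoutBase:
        pass

    # Both forms produce the same method resolution order in Python 3
    assert WithBase.__mro__[1] is object and WithoutBase.__mro__[1] is object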
torchvision/datasets/vision.py

@@ -43,7 +43,7 @@ class VisionDataset(data.Dataset):
         has_transforms = transforms is not None
         has_separate_transform = transform is not None or target_transform is not None
         if has_transforms and has_separate_transform:
-            raise ValueError("Only transforms or transform/target_transform can " "be passed as argument")
+            raise ValueError("Only transforms or transform/target_transform can be passed as argument")

         # for backwards-compatibility
         self.transform = transform
@@ -68,9 +68,9 @@ class VisionDataset(data.Dataset):
     def __repr__(self) -> str:
         head = "Dataset " + self.__class__.__name__
-        body = ["Number of datapoints: {}".format(self.__len__())]
+        body = [f"Number of datapoints: {self.__len__()}"]
         if self.root is not None:
-            body.append("Root location: {}".format(self.root))
+            body.append(f"Root location: {self.root}")
         body += self.extra_repr().splitlines()
         if hasattr(self, "transforms") and self.transforms is not None:
             body += [repr(self.transforms)]
@@ -79,13 +79,13 @@ class VisionDataset(data.Dataset):
     def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
         lines = transform.__repr__().splitlines()
-        return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

     def extra_repr(self) -> str:
         return ""


-class StandardTransform(object):
+class StandardTransform:
     def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
         self.transform = transform
         self.target_transform = target_transform
@@ -99,7 +99,7 @@ class StandardTransform(object):
     def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
         lines = transform.__repr__().splitlines()
-        return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
+        return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

     def __repr__(self) -> str:
         body = [self.__class__.__name__]
torchvision/datasets/voc.py

@@ -114,7 +114,7 @@ class _VOCBase(VisionDataset):
         splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
         split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
-        with open(os.path.join(split_f), "r") as f:
+        with open(os.path.join(split_f)) as f:
             file_names = [x.strip() for x in f.readlines()]

         image_dir = os.path.join(voc_root, "JPEGImages")
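The open() cleanup here (and in sbd.py, stl10.py, ucf101.py, and widerface.py) relies on text-read being the default: open(path) is exactly open(path, "r"). A short sketch using a temporary file:

    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("2007_000032\n2007_000039\n")

    # "r" is the default mode, so pyupgrade drops the explicit argument
    with open(tmp.name) as f:
        file_names = [x.strip() for x in f.readlines()]
    assert file_names == ["2007_000032", "2007_000039"]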
torchvision/datasets/widerface.py

@@ -62,7 +62,7 @@ class WIDERFace(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(WIDERFace, self).__init__(
+        super().__init__(
             root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
         )
         # check arguments
@@ -72,9 +72,7 @@ class WIDERFace(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError(
-                "Dataset not found or corrupted. " + "You can use download=True to download and prepare it"
-            )
+            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")

         self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
         if self.split in ("train", "val"):
@@ -115,7 +113,7 @@ class WIDERFace(VisionDataset):
         filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
         filepath = os.path.join(self.root, "wider_face_split", filename)
-        with open(filepath, "r") as f:
+        with open(filepath) as f:
             lines = f.readlines()
             file_name_line, num_boxes_line, box_annotation_line = True, False, False
             num_boxes, box_counter = 0, 0
@@ -157,12 +155,12 @@ class WIDERFace(VisionDataset):
                 box_counter = 0
                 labels.clear()
             else:
-                raise RuntimeError("Error parsing annotation file {}".format(filepath))
+                raise RuntimeError(f"Error parsing annotation file {filepath}")

     def parse_test_annotations_file(self) -> None:
         filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
         filepath = abspath(expanduser(filepath))
-        with open(filepath, "r") as f:
+        with open(filepath) as f:
             lines = f.readlines()
             for line in lines:
                 line = line.rstrip()
torchvision/extension.py

@@ -60,10 +60,9 @@ def _check_cuda_version():
     if t_major != tv_major or t_minor != tv_minor:
         raise RuntimeError(
             "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
-            "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
-            "Please reinstall the torchvision that matches your PyTorch install.".format(
-                t_major, t_minor, tv_major, tv_minor
-            )
+            f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
+            f"CUDA Version={tv_major}.{tv_minor}. "
+            "Please reinstall the torchvision that matches your PyTorch install."
         )
     return _version
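Long messages can stay split across lines after the conversion: adjacent string literals, f-strings included, are still concatenated at compile time, so only the pieces with placeholders need the f prefix. A small sketch with stand-in version numbers:

    t_major, t_minor, tv_major, tv_minor = 11, 3, 10, 2  # stand-in versions

    msg = (
        "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
        f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
        f"CUDA Version={tv_major}.{tv_minor}. "
        "Please reinstall the torchvision that matches your PyTorch install."
    )
    assert "CUDA Version=11.3" in msg and "CUDA Version=10.2" in msg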
torchvision/io/_video_opt.py

@@ -20,7 +20,7 @@ default_timebase = Fraction(0, 1)
 # simple class for torch scripting
 # the complex Fraction class from fractions module is not scriptable
-class Timebase(object):
+class Timebase:
     __annotations__ = {"numerator": int, "denominator": int}
     __slots__ = ["numerator", "denominator"]
@@ -34,7 +34,7 @@ class Timebase(object):
         self.denominator = denominator


-class VideoMetaData(object):
+class VideoMetaData:
     __annotations__ = {
         "has_video": bool,
         "video_timebase": Timebase,
torchvision/io/image.py

@@ -161,7 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
         JPEG file.
     """
     if quality < 1 or quality > 100:
-        raise ValueError("Image quality should be a positive number " "between 1 and 100")
+        raise ValueError("Image quality should be a positive number between 1 and 100")

     output = torch.ops.image.encode_jpeg(input, quality)
     return output
torchvision/io/video.py

@@ -271,9 +271,7 @@ def read_video(
         end_pts = float("inf")

     if end_pts < start_pts:
-        raise ValueError(
-            "end_pts should be larger than start_pts, got "
-            "start_pts={} and end_pts={}".format(start_pts, end_pts)
-        )
+        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")

     info = {}
     video_frames = []
torchvision/models/_utils.py

@@ -54,7 +54,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
             if not return_layers:
                 break

-        super(IntermediateLayerGetter, self).__init__(layers)
+        super().__init__(layers)
         self.return_layers = orig_return_layers

     def forward(self, x):
torchvision/models/alexnet.py

@@ -17,7 +17,7 @@ model_urls = {
 class AlexNet(nn.Module):
     def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
-        super(AlexNet, self).__init__()
+        super().__init__()
         _log_api_usage_once(self)
         self.features = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
torchvision/models/densenet.py

@@ -26,7 +26,7 @@ class _DenseLayer(nn.Module):
     def __init__(
         self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
     ) -> None:
-        super(_DenseLayer, self).__init__()
+        super().__init__()
         self.norm1: nn.BatchNorm2d
         self.add_module("norm1", nn.BatchNorm2d(num_input_features))
         self.relu1: nn.ReLU
@@ -107,7 +107,7 @@ class _DenseBlock(nn.ModuleDict):
         drop_rate: float,
         memory_efficient: bool = False,
     ) -> None:
-        super(_DenseBlock, self).__init__()
+        super().__init__()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -128,7 +128,7 @@ class _DenseBlock(nn.ModuleDict):
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features: int, num_output_features: int) -> None:
-        super(_Transition, self).__init__()
+        super().__init__()
         self.add_module("norm", nn.BatchNorm2d(num_input_features))
         self.add_module("relu", nn.ReLU(inplace=True))
         self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
@@ -162,7 +162,7 @@ class DenseNet(nn.Module):
         memory_efficient: bool = False,
     ) -> None:
-        super(DenseNet, self).__init__()
+        super().__init__()
         _log_api_usage_once(self)

         # First convolution