Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
cc26cd81
Commit
cc26cd81
authored
Nov 27, 2023
by
panning
Browse files
merge v0.16.0
parents
f78f29f5
fbb4cc54
Changes
370
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
146 additions
and
60 deletions
+146
-60
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+1
-1
torchvision/datasets/country211.py
torchvision/datasets/country211.py
+1
-1
torchvision/datasets/dtd.py
torchvision/datasets/dtd.py
+2
-2
torchvision/datasets/fgvc_aircraft.py
torchvision/datasets/fgvc_aircraft.py
+1
-1
torchvision/datasets/flowers102.py
torchvision/datasets/flowers102.py
+1
-1
torchvision/datasets/food101.py
torchvision/datasets/food101.py
+1
-1
torchvision/datasets/hmdb51.py
torchvision/datasets/hmdb51.py
+1
-1
torchvision/datasets/imagenet.py
torchvision/datasets/imagenet.py
+6
-0
torchvision/datasets/lfw.py
torchvision/datasets/lfw.py
+11
-11
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+7
-8
torchvision/datasets/moving_mnist.py
torchvision/datasets/moving_mnist.py
+93
-0
torchvision/datasets/places365.py
torchvision/datasets/places365.py
+2
-2
torchvision/datasets/rendered_sst2.py
torchvision/datasets/rendered_sst2.py
+1
-1
torchvision/datasets/sbu.py
torchvision/datasets/sbu.py
+2
-7
torchvision/datasets/stl10.py
torchvision/datasets/stl10.py
+1
-1
torchvision/datasets/sun397.py
torchvision/datasets/sun397.py
+1
-1
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+1
-1
torchvision/datasets/ucf101.py
torchvision/datasets/ucf101.py
+1
-1
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+8
-15
torchvision/datasets/video_utils.py
torchvision/datasets/video_utils.py
+4
-4
No files found.
Too many changes to show.
To preserve performance only
370 of 370+
files are displayed.
Plain diff
Email patch
torchvision/datasets/cityscapes.py
View file @
cc26cd81
...
...
@@ -177,7 +177,7 @@ class Cityscapes(VisionDataset):
index (int): Index
Returns:
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
than one item. Otherwise
,
target is a json object if target_type="polygon", else the image segmentation.
"""
image
=
Image
.
open
(
self
.
images
[
index
]).
convert
(
"RGB"
)
...
...
torchvision/datasets/country211.py
View file @
cc26cd81
...
...
@@ -11,7 +11,7 @@ class Country211(ImageFolder):
This dataset was built by filtering the images from the YFCC100m dataset
that have GPS coordinate corresponding to a ISO-3166 country code. The
dataset is balanced by sampling 150 train images, 50 validation images, and
100 test images
images
for each country.
100 test images for each country.
Args:
root (string): Root directory of the dataset.
...
...
torchvision/datasets/dtd.py
View file @
cc26cd81
import
os
import
pathlib
from
typing
import
Callable
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
import
PIL.Image
...
...
@@ -76,7 +76,7 @@ class DTD(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_image_files
)
def
__getitem__
(
self
,
idx
)
:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]
:
image_file
,
label
=
self
.
_image_files
[
idx
],
self
.
_labels
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/fgvc_aircraft.py
View file @
cc26cd81
...
...
@@ -90,7 +90,7 @@ class FGVCAircraft(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_image_files
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
Any
,
Any
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]:
image_file
,
label
=
self
.
_image_files
[
idx
],
self
.
_labels
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/flowers102.py
View file @
cc26cd81
...
...
@@ -76,7 +76,7 @@ class Flowers102(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_image_files
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
Any
,
Any
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]:
image_file
,
label
=
self
.
_image_files
[
idx
],
self
.
_labels
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/food101.py
View file @
cc26cd81
...
...
@@ -69,7 +69,7 @@ class Food101(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_image_files
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
Any
,
Any
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]:
image_file
,
label
=
self
.
_image_files
[
idx
],
self
.
_labels
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/hmdb51.py
View file @
cc26cd81
...
...
@@ -102,7 +102,7 @@ class HMDB51(VisionDataset):
output_format
=
output_format
,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta
data of full version rather than the subset version of
# to return the metadata of full version rather than the subset version of
# video clips
self
.
full_video_clips
=
video_clips
self
.
fold
=
fold
...
...
torchvision/datasets/imagenet.py
View file @
cc26cd81
...
...
@@ -21,6 +21,12 @@ META_FILE = "meta.bin"
class
ImageNet
(
ImageFolder
):
"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
.. note::
Before using this class, it is required to download ImageNet 2012 dataset from
`here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
Args:
root (string): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
...
...
torchvision/datasets/lfw.py
View file @
cc26cd81
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
...
...
@@ -38,7 +38,7 @@ class _LFW(VisionDataset):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
os
.
path
.
join
(
root
,
self
.
base_folder
),
transform
=
transform
,
target_transform
=
target_transform
)
self
.
image_set
=
verify_str_arg
(
image_set
.
lower
(),
"image_set"
,
self
.
file_dict
.
keys
())
...
...
@@ -62,7 +62,7 @@ class _LFW(VisionDataset):
img
=
Image
.
open
(
f
)
return
img
.
convert
(
"RGB"
)
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
st1
=
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
),
self
.
md5
)
st2
=
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
),
self
.
checksums
[
self
.
labels_file
])
if
not
st1
or
not
st2
:
...
...
@@ -71,7 +71,7 @@ class _LFW(VisionDataset):
return
check_integrity
(
os
.
path
.
join
(
self
.
root
,
self
.
names
),
self
.
checksums
[
self
.
names
])
return
True
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
"Files already downloaded and verified"
)
return
...
...
@@ -81,13 +81,13 @@ class _LFW(VisionDataset):
if
self
.
view
==
"people"
:
download_url
(
f
"
{
self
.
download_url_prefix
}{
self
.
names
}
"
,
self
.
root
)
def
_get_path
(
self
,
identity
,
no
)
:
def
_get_path
(
self
,
identity
:
str
,
no
:
Union
[
int
,
str
])
->
str
:
return
os
.
path
.
join
(
self
.
images_dir
,
identity
,
f
"
{
identity
}
_
{
int
(
no
):
04
d
}
.jpg"
)
def
extra_repr
(
self
)
->
str
:
return
f
"Alignment:
{
self
.
image_set
}
\n
Split:
{
self
.
split
}
"
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
...
...
@@ -119,13 +119,13 @@ class LFWPeople(_LFW):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
root
,
split
,
image_set
,
"people"
,
transform
,
target_transform
,
download
)
self
.
class_to_idx
=
self
.
_get_classes
()
self
.
data
,
self
.
targets
=
self
.
_get_people
()
def
_get_people
(
self
):
def
_get_people
(
self
)
->
Tuple
[
List
[
str
],
List
[
int
]]
:
data
,
targets
=
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
...
...
@@ -143,7 +143,7 @@ class LFWPeople(_LFW):
return
data
,
targets
def
_get_classes
(
self
):
def
_get_classes
(
self
)
->
Dict
[
str
,
int
]
:
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
names
))
as
f
:
lines
=
f
.
readlines
()
names
=
[
line
.
strip
().
split
()[
0
]
for
line
in
lines
]
...
...
@@ -201,12 +201,12 @@ class LFWPairs(_LFW):
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
):
)
->
None
:
super
().
__init__
(
root
,
split
,
image_set
,
"pairs"
,
transform
,
target_transform
,
download
)
self
.
pair_names
,
self
.
data
,
self
.
targets
=
self
.
_get_pairs
(
self
.
images_dir
)
def
_get_pairs
(
self
,
images_dir
)
:
def
_get_pairs
(
self
,
images_dir
:
str
)
->
Tuple
[
List
[
Tuple
[
str
,
str
]],
List
[
Tuple
[
str
,
str
]],
List
[
int
]]
:
pair_names
,
data
,
targets
=
[],
[],
[]
with
open
(
os
.
path
.
join
(
self
.
root
,
self
.
labels_file
))
as
f
:
lines
=
f
.
readlines
()
...
...
torchvision/datasets/mnist.py
View file @
cc26cd81
...
...
@@ -12,7 +12,7 @@ import numpy as np
import
torch
from
PIL
import
Image
from
.utils
import
check_integrity
,
download_and_extract_archive
,
extract_archive
,
verify_str_arg
from
.utils
import
_flip_byte_order
,
check_integrity
,
download_and_extract_archive
,
extract_archive
,
verify_str_arg
from
.vision
import
VisionDataset
...
...
@@ -366,7 +366,7 @@ class QMNIST(MNIST):
that takes in the target and transforms it.
train (bool,optional,compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set o
t
the testing set. Default: True.
training set o
r
the testing set. Default: True.
"""
subsets
=
{
"train"
:
"train"
,
"test"
:
"test"
,
"test10k"
:
"test"
,
"test50k"
:
"test"
,
"nist"
:
"nist"
}
...
...
@@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
torch_type
=
SN3_PASCALVINCENT_TYPEMAP
[
ty
]
s
=
[
get_int
(
data
[
4
*
(
i
+
1
)
:
4
*
(
i
+
2
)])
for
i
in
range
(
nd
)]
num_bytes_per_value
=
torch
.
iinfo
(
torch_type
).
bits
//
8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal
=
sys
.
byteorder
==
"little"
and
num_bytes_per_value
>
1
parsed
=
torch
.
frombuffer
(
bytearray
(
data
),
dtype
=
torch_type
,
offset
=
(
4
*
(
nd
+
1
)))
if
needs_byte_reversal
:
parsed
=
parsed
.
flip
(
0
)
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
# that is little endian and the dtype has more than one byte, we need to flip them.
if
sys
.
byteorder
==
"little"
and
parsed
.
element_size
()
>
1
:
parsed
=
_flip_byte_order
(
parsed
)
assert
parsed
.
shape
[
0
]
==
np
.
prod
(
s
)
or
not
strict
return
parsed
.
view
(
*
s
)
...
...
torchvision/datasets/moving_mnist.py
0 → 100644
View file @
cc26cd81
import
os.path
from
typing
import
Callable
,
Optional
import
numpy
as
np
import
torch
from
torchvision.datasets.utils
import
download_url
,
verify_str_arg
from
torchvision.datasets.vision
import
VisionDataset
class MovingMNIST(VisionDataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
            If ``split=None``, the full data is returned.
        split_ratio (int, optional): Number of frames at which each sequence is split; must satisfy
            ``1 <= split_ratio <= 19``. If ``split="train"``, the first frames ``data[:, :split_ratio]`` are
            returned. If ``split="test"``, the remaining frames ``data[:, split_ratio:]`` are returned.
            If ``split=None``, the value is still validated, but all frames are returned.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a torch Tensor
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """

    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

    def __init__(
        self,
        root: str,
        split: Optional[str] = None,
        split_ratio: int = 10,
        download: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        # The file lives in a sub-folder named after the class, under the file name taken from the URL.
        self._base_folder = os.path.join(self.root, type(self).__name__)
        self._filename = self._URL.rsplit("/", 1)[-1]

        # ``None`` means "no split"; any other value has to be one of the two supported names.
        if split is not None:
            verify_str_arg(split, "split", ("train", "test"))
        self.split = split

        # ``split_ratio`` is validated even when ``split`` is ``None``.
        if not isinstance(split_ratio, int):
            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
        if split_ratio < 1 or split_ratio > 19:
            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
        self.split_ratio = split_ratio

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        archive_path = os.path.join(self._base_folder, self._filename)
        frames = torch.from_numpy(np.load(archive_path))

        # The split is taken along the first axis of the array as stored on disk
        # (presumably the frame axis -- confirm against the file layout), *before*
        # axes 0 and 1 are swapped below.
        if self.split == "train":
            frames = frames[: self.split_ratio]
        elif self.split == "test":
            frames = frames[self.split_ratio :]

        # Swap the first two axes and insert a singleton axis at position 2
        # (the channel axis, going by the ``[T, C, H, W]`` contract of ``__getitem__``).
        swapped = frames.transpose(0, 1)
        self.data = swapped.unsqueeze(2).contiguous()

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Args:
            idx (int): Index

        Returns:
            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
        """
        sample = self.data[idx]
        if self.transform is None:
            return sample
        return self.transform(sample)

    def __len__(self) -> int:
        """Return the length of ``self.data`` along its first axis."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return ``True`` if the dataset file is present in the base folder (existence check only)."""
        target = os.path.join(self._base_folder, self._filename)
        return os.path.exists(target)

    def download(self) -> None:
        """Fetch the dataset file into the base folder unless it is already there."""
        if not self._check_exists():
            download_url(
                url=self._URL,
                root=self._base_folder,
                filename=self._filename,
                md5="be083ec986bfe91a449d63653c411eb2",
            )
torchvision/datasets/places365.py
View file @
cc26cd81
...
...
@@ -15,7 +15,7 @@ class Places365(VisionDataset):
root (string): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
``val``.
small (bool, optional): If ``True``, uses the small images, i.
e. resized to 256 x 256 pixels, instead of the
small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
downloaded archives are not downloaded again.
...
...
@@ -32,7 +32,7 @@ class Places365(VisionDataset):
targets (list): The class_index value for each image in the dataset
Raises:
RuntimeError: If ``download is False`` and the meta files, i.
e. the devkit, are not present or corrupted.
RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted.
"""
_SPLITS
=
(
"train-standard"
,
"train-challenge"
,
"val"
)
...
...
torchvision/datasets/rendered_sst2.py
View file @
cc26cd81
...
...
@@ -59,7 +59,7 @@ class RenderedSST2(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_samples
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
Any
,
Any
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]:
image_file
,
label
=
self
.
_samples
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/sbu.py
View file @
cc26cd81
...
...
@@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Tuple
from
PIL
import
Image
from
.utils
import
check_integrity
,
download_url
from
.utils
import
check_integrity
,
download_and_extract_archive
,
download_url
from
.vision
import
VisionDataset
...
...
@@ -90,17 +90,12 @@ class SBU(VisionDataset):
def
download
(
self
)
->
None
:
"""Download and extract the tarball, and download each individual photo."""
import
tarfile
if
self
.
_check_integrity
():
print
(
"Files already downloaded and verified"
)
return
download_url
(
self
.
url
,
self
.
root
,
self
.
filename
,
self
.
md5_checksum
)
# Extract file
with
tarfile
.
open
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
),
"r:gz"
)
as
tar
:
tar
.
extractall
(
path
=
self
.
root
)
download_and_extract_archive
(
self
.
url
,
self
.
root
,
self
.
root
,
self
.
filename
,
self
.
md5_checksum
)
# Download individual photos
with
open
(
os
.
path
.
join
(
self
.
root
,
"dataset"
,
"SBU_captioned_photo_dataset_urls.txt"
))
as
fh
:
...
...
torchvision/datasets/stl10.py
View file @
cc26cd81
...
...
@@ -15,7 +15,7 @@ class STL10(VisionDataset):
root (string): Root directory of dataset where directory
``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly dataset is selected.
Accordingly
,
dataset is selected.
folds (int, optional): One of {0-9} or None.
For training, loads one of the 10 pre-defined folds of 1k samples for the
standard evaluation procedure. If no value is passed, loads the 5k samples.
...
...
torchvision/datasets/sun397.py
View file @
cc26cd81
...
...
@@ -55,7 +55,7 @@ class SUN397(VisionDataset):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_image_files
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
Any
,
Any
]:
def
__getitem__
(
self
,
idx
:
int
)
->
Tuple
[
Any
,
Any
]:
image_file
,
label
=
self
.
_image_files
[
idx
],
self
.
_labels
[
idx
]
image
=
PIL
.
Image
.
open
(
image_file
).
convert
(
"RGB"
)
...
...
torchvision/datasets/svhn.py
View file @
cc26cd81
...
...
@@ -78,7 +78,7 @@ class SVHN(VisionDataset):
loaded_mat
=
sio
.
loadmat
(
os
.
path
.
join
(
self
.
root
,
self
.
filename
))
self
.
data
=
loaded_mat
[
"X"
]
# loading from the .mat file gives an np
array of type np.uint8
# loading from the .mat file gives an np
.nd
array of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
...
...
torchvision/datasets/ucf101.py
View file @
cc26cd81
...
...
@@ -93,7 +93,7 @@ class UCF101(VisionDataset):
output_format
=
output_format
,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta
data of full version rather than the subset version of
# to return the metadata of full version rather than the subset version of
# video clips
self
.
full_video_clips
=
video_clips
self
.
indices
=
self
.
_select_fold
(
video_list
,
annotation_path
,
fold
,
train
)
...
...
torchvision/datasets/utils.py
View file @
cc26cd81
...
...
@@ -48,19 +48,6 @@ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
_save_response_content
(
iter
(
lambda
:
response
.
read
(
chunk_size
),
b
""
),
filename
,
length
=
response
.
length
)
def
gen_bar_updater
()
->
Callable
[[
int
,
int
,
int
],
None
]:
warnings
.
warn
(
"The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15."
)
pbar
=
tqdm
(
total
=
None
)
def
bar_update
(
count
,
block_size
,
total_size
):
if
pbar
.
total
is
None
and
total_size
:
pbar
.
total
=
total_size
progress_bytes
=
count
*
block_size
pbar
.
update
(
progress_bytes
-
pbar
.
n
)
return
bar_update
def
calculate_md5
(
fpath
:
str
,
chunk_size
:
int
=
1024
*
1024
)
->
str
:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
...
...
@@ -70,7 +57,7 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
else
:
md5
=
hashlib
.
md5
()
with
open
(
fpath
,
"rb"
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
chunk_size
)
,
b
""
)
:
while
chunk
:
=
f
.
read
(
chunk_size
):
md5
.
update
(
chunk
)
return
md5
.
hexdigest
()
...
...
@@ -464,7 +451,7 @@ def verify_str_arg(
valid_values
:
Optional
[
Iterable
[
T
]]
=
None
,
custom_msg
:
Optional
[
str
]
=
None
,
)
->
T
:
if
not
isinstance
(
value
,
torch
.
_six
.
string_classes
):
if
not
isinstance
(
value
,
str
):
if
arg
is
None
:
msg
=
"Expected type str, but got type {type}."
else
:
...
...
@@ -520,3 +507,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
data
=
np
.
flip
(
data
,
axis
=
1
)
# flip on h dimension
data
=
data
[:
slice_channels
,
:,
:]
return
data
.
astype
(
np
.
float32
)
def
_flip_byte_order
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
(
t
.
contiguous
().
view
(
torch
.
uint8
).
view
(
*
t
.
shape
,
t
.
element_size
()).
flip
(
-
1
).
view
(
*
t
.
shape
[:
-
1
],
-
1
).
view
(
t
.
dtype
)
)
torchvision/datasets/video_utils.py
View file @
cc26cd81
...
...
@@ -49,7 +49,7 @@ class _VideoTimestampsDataset:
Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem.
Used in VideoClips and defined at top level so it can be
Used in VideoClips and defined at top level
,
so it can be
pickled when forking.
"""
...
...
@@ -187,9 +187,9 @@ class VideoClips:
}
return
type
(
self
)(
video_paths
,
self
.
num_frames
,
self
.
step
,
self
.
frame_rate
,
clip_length_in_frames
=
self
.
num_frames
,
frames_between_clips
=
self
.
step
,
frame_rate
=
self
.
frame_rate
,
_precomputed_metadata
=
metadata
,
num_workers
=
self
.
num_workers
,
_video_width
=
self
.
_video_width
,
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment