Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2ba586d5
Unverified
Commit
2ba586d5
authored
Mar 18, 2024
by
Nicolas Hug
Committed by
GitHub
Mar 18, 2024
Browse files
Document that datasets support pathlib.Path (#8321)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
03251754
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
110 additions
and
100 deletions
+110
-100
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+11
-12
torchvision/datasets/_stereo_matching.py
torchvision/datasets/_stereo_matching.py
+21
-21
torchvision/datasets/caltech.py
torchvision/datasets/caltech.py
+4
-3
torchvision/datasets/celeba.py
torchvision/datasets/celeba.py
+3
-2
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+4
-3
torchvision/datasets/cityscapes.py
torchvision/datasets/cityscapes.py
+3
-2
torchvision/datasets/clevr.py
torchvision/datasets/clevr.py
+3
-3
torchvision/datasets/coco.py
torchvision/datasets/coco.py
+5
-4
torchvision/datasets/country211.py
torchvision/datasets/country211.py
+3
-3
torchvision/datasets/dtd.py
torchvision/datasets/dtd.py
+3
-3
torchvision/datasets/eurosat.py
torchvision/datasets/eurosat.py
+4
-3
torchvision/datasets/fer2013.py
torchvision/datasets/fer2013.py
+3
-3
torchvision/datasets/fgvc_aircraft.py
torchvision/datasets/fgvc_aircraft.py
+4
-3
torchvision/datasets/flickr.py
torchvision/datasets/flickr.py
+6
-5
torchvision/datasets/flowers102.py
torchvision/datasets/flowers102.py
+3
-3
torchvision/datasets/folder.py
torchvision/datasets/folder.py
+8
-7
torchvision/datasets/food101.py
torchvision/datasets/food101.py
+3
-3
torchvision/datasets/gtsrb.py
torchvision/datasets/gtsrb.py
+3
-3
torchvision/datasets/hmdb51.py
torchvision/datasets/hmdb51.py
+4
-3
torchvision/datasets/imagenet.py
torchvision/datasets/imagenet.py
+12
-11
No files found.
torchvision/datasets/_optical_flow.py
View file @
2ba586d5
...
@@ -13,7 +13,6 @@ from ..io.image import _read_png_16
...
@@ -13,7 +13,6 @@ from ..io.image import _read_png_16
from
.utils
import
_read_pfm
,
verify_str_arg
from
.utils
import
_read_pfm
,
verify_str_arg
from
.vision
import
VisionDataset
from
.vision
import
VisionDataset
T1
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
],
Optional
[
np
.
ndarray
]]
T1
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
],
Optional
[
np
.
ndarray
]]
T2
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
]]
T2
=
Tuple
[
Image
.
Image
,
Image
.
Image
,
Optional
[
np
.
ndarray
]]
...
@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
...
@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask
=
False
_has_builtin_flow_mask
=
False
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
=
root
)
super
().
__init__
(
root
=
root
)
self
.
transforms
=
transforms
self
.
transforms
=
transforms
...
@@ -113,7 +112,7 @@ class Sintel(FlowDataset):
...
@@ -113,7 +112,7 @@ class Sintel(FlowDataset):
...
...
Args:
Args:
root (str
ing
): Root directory of the Sintel Dataset.
root (str
or ``pathlib.Path``
): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
details on the different passes.
...
@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
...
@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
pass_name
:
str
=
"clean"
,
pass_name
:
str
=
"clean"
,
transforms
:
Optional
[
Callable
]
=
None
,
transforms
:
Optional
[
Callable
]
=
None
,
...
@@ -183,7 +182,7 @@ class KittiFlow(FlowDataset):
...
@@ -183,7 +182,7 @@ class KittiFlow(FlowDataset):
flow_occ
flow_occ
Args:
Args:
root (str
ing
): Root directory of the KittiFlow Dataset.
root (str
or ``pathlib.Path``
): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...
@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):
...
@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask
=
True
_has_builtin_flow_mask
=
True
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
...
@@ -248,7 +247,7 @@ class FlyingChairs(FlowDataset):
...
@@ -248,7 +247,7 @@ class FlyingChairs(FlowDataset):
Args:
Args:
root (str
ing
): Root directory of the FlyingChairs Dataset.
root (str
or ``pathlib.Path``
): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...
@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
...
@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
"""
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"val"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"val"
))
...
@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
...
@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
TRAIN
TRAIN
Args:
Args:
root (str
ing
): Root directory of the intel FlyingThings3D Dataset.
root (str
or ``pathlib.Path``
): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
details on the different passes.
...
@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
...
@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
pass_name
:
str
=
"clean"
,
pass_name
:
str
=
"clean"
,
camera
:
str
=
"left"
,
camera
:
str
=
"left"
,
...
@@ -411,7 +410,7 @@ class HD1K(FlowDataset):
...
@@ -411,7 +410,7 @@ class HD1K(FlowDataset):
image_2
image_2
Args:
Args:
root (str
ing
): Root directory of the HD1K Dataset.
root (str
or ``pathlib.Path``
): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...
@@ -419,7 +418,7 @@ class HD1K(FlowDataset):
...
@@ -419,7 +418,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask
=
True
_has_builtin_flow_mask
=
True
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
...
...
torchvision/datasets/_stereo_matching.py
View file @
2ba586d5
...
@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
...
@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask
=
False
_has_built_in_disparity_mask
=
False
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
"""
"""
Args:
Args:
root(str): Root directory of the dataset.
root(str): Root directory of the dataset.
...
@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
...
@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory where `carla-highres` is located.
root (str
or ``pathlib.Path``
): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
def
__init__
(
self
,
root
:
str
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
root
=
Path
(
root
)
/
"carla-highres"
root
=
Path
(
root
)
/
"carla-highres"
...
@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
...
@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
calib
calib
Args:
Args:
root (str
ing
): Root directory where `Kitti2012` is located.
root (str
or ``pathlib.Path``
): Root directory where `Kitti2012` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
_has_built_in_disparity_mask
=
True
_has_built_in_disparity_mask
=
True
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
...
@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
...
@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
calib
calib
Args:
Args:
root (str
ing
): Root directory where `Kitti2015` is located.
root (str
or ``pathlib.Path``
): Root directory where `Kitti2015` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
_has_built_in_disparity_mask
=
True
_has_built_in_disparity_mask
=
True
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
...
@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory of the Middleburry 2014 Dataset.
root (str
or ``pathlib.Path``
): Root directory of the Middleburry 2014 Dataset.
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
...
@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
calibration
:
Optional
[
str
]
=
"perfect"
,
calibration
:
Optional
[
str
]
=
"perfect"
,
use_ambient_views
:
bool
=
False
,
use_ambient_views
:
bool
=
False
,
...
@@ -576,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
@@ -576,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask
=
(
disparity_map
>
0
).
squeeze
(
0
)
# mask out invalid disparities
valid_mask
=
(
disparity_map
>
0
).
squeeze
(
0
)
# mask out invalid disparities
return
disparity_map
,
valid_mask
return
disparity_map
,
valid_mask
def
_download_dataset
(
self
,
root
:
str
)
->
None
:
def
_download_dataset
(
self
,
root
:
Union
[
str
,
Path
]
)
->
None
:
base_url
=
"https://vision.middlebury.edu/stereo/data/scenes2014/zip"
base_url
=
"https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
# train and additional splits have 2 different calibration settings
root
=
Path
(
root
)
/
"Middlebury2014"
root
=
Path
(
root
)
/
"Middlebury2014"
...
@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
...
@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
transforms
:
Optional
[
Callable
]
=
None
,
transforms
:
Optional
[
Callable
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
...
@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
...
@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory where FallingThings is located.
root (str
or ``pathlib.Path``
): Root directory where FallingThings is located.
variant (string): Which variant to use. Either "single", "mixed", or "both".
variant (string): Which variant to use. Either "single", "mixed", or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
def
__init__
(
self
,
root
:
str
,
variant
:
str
=
"single"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
variant
:
str
=
"single"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
root
=
Path
(
root
)
/
"FallingThings"
root
=
Path
(
root
)
/
"FallingThings"
...
@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...
@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory where SceneFlow is located.
root (str
or ``pathlib.Path``
): Root directory where SceneFlow is located.
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
...
@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...
@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
variant
:
str
=
"FlyingThings3D"
,
variant
:
str
=
"FlyingThings3D"
,
pass_name
:
str
=
"clean"
,
pass_name
:
str
=
"clean"
,
transforms
:
Optional
[
Callable
]
=
None
,
transforms
:
Optional
[
Callable
]
=
None
,
...
@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
...
@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory where Sintel Stereo is located.
root (str
or ``pathlib.Path``
): Root directory where Sintel Stereo is located.
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
_has_built_in_disparity_mask
=
True
_has_built_in_disparity_mask
=
True
def
__init__
(
self
,
root
:
str
,
pass_name
:
str
=
"final"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
pass_name
:
str
=
"final"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"final"
,
"clean"
,
"both"
))
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"final"
,
"clean"
,
"both"
))
...
@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
...
@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory where InStereo2k is located.
root (str
or ``pathlib.Path``
): Root directory where InStereo2k is located.
split (string): Either "train" or "test".
split (string): Either "train" or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
root
=
Path
(
root
)
/
"InStereo2k"
/
split
root
=
Path
(
root
)
/
"InStereo2k"
/
split
...
@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
...
@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
...
...
Args:
Args:
root (str
ing
): Root directory of the ETH3D Dataset.
root (str
or ``pathlib.Path``
): Root directory of the ETH3D Dataset.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
"""
_has_built_in_disparity_mask
=
True
_has_built_in_disparity_mask
=
True
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
transforms
:
Optional
[
Callable
]
=
None
)
->
None
:
super
().
__init__
(
root
,
transforms
)
super
().
__init__
(
root
,
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
...
...
torchvision/datasets/caltech.py
View file @
2ba586d5
import
os
import
os
import
os.path
import
os.path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
from
PIL
import
Image
...
@@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
...
@@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
Args:
Args:
root (str
ing
): Root directory of dataset where directory
root (str
or ``pathlib.Path``
): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified
``annotation``. Can also be a list to output a tuple with all specified
...
@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
...
@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
target_type
:
Union
[
List
[
str
],
str
]
=
"category"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"category"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
@@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
...
@@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of dataset where directory
root (str
or ``pathlib.Path``
): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
...
...
torchvision/datasets/celeba.py
View file @
2ba586d5
import
csv
import
csv
import
os
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
import
PIL
import
PIL
...
@@ -16,7 +17,7 @@ class CelebA(VisionDataset):
...
@@ -16,7 +17,7 @@ class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory where images are downloaded to.
root (str
or ``pathlib.Path``
): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test', 'all'}.
split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected.
Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
...
@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
...
@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"attr"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"attr"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/cifar.py
View file @
2ba586d5
import
os.path
import
os.path
import
pickle
import
pickle
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
...
@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
...
@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of dataset where directory
root (str
or ``pathlib.Path``
): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
creates from test set.
...
@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
...
@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
train
:
bool
=
True
,
train
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/cityscapes.py
View file @
2ba586d5
import
json
import
json
import
os
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
from
PIL
import
Image
...
@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
...
@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of dataset where directory ``leftImg8bit``
root (str
or ``pathlib.Path``
): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val``
otherwise ``train``, ``train_extra`` or ``val``
...
@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
...
@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
mode
:
str
=
"fine"
,
mode
:
str
=
"fine"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"instance"
,
target_type
:
Union
[
List
[
str
],
str
]
=
"instance"
,
...
...
torchvision/datasets/clevr.py
View file @
2ba586d5
import
json
import
json
import
pathlib
import
pathlib
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
from
PIL
import
Image
from
PIL
import
Image
...
@@ -15,7 +15,7 @@ class CLEVRClassification(VisionDataset):
...
@@ -15,7 +15,7 @@ class CLEVRClassification(VisionDataset):
The number of objects in a scene are used as label.
The number of objects in a scene are used as label.
Args:
Args:
root (str
ing
): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
root (str
or ``pathlib.Path``
): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True.
set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
...
@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):
...
@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/coco.py
View file @
2ba586d5
import
os.path
import
os.path
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
from
PIL
import
Image
...
@@ -12,7 +13,7 @@ class CocoDetection(VisionDataset):
...
@@ -12,7 +13,7 @@ class CocoDetection(VisionDataset):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args:
Args:
root (str
ing
): Root directory where images are downloaded to.
root (str
or ``pathlib.Path``
): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
and returns a transformed version. E.g, ``transforms.PILToTensor``
...
@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):
...
@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
annFile
:
str
,
annFile
:
str
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
@@ -67,7 +68,7 @@ class CocoCaptions(CocoDetection):
...
@@ -67,7 +68,7 @@ class CocoCaptions(CocoDetection):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args:
Args:
root (str
ing
): Root directory where images are downloaded to.
root (str
or ``pathlib.Path``
): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
and returns a transformed version. E.g, ``transforms.PILToTensor``
...
...
torchvision/datasets/country211.py
View file @
2ba586d5
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
,
Union
from
.folder
import
ImageFolder
from
.folder
import
ImageFolder
from
.utils
import
download_and_extract_archive
,
verify_str_arg
from
.utils
import
download_and_extract_archive
,
verify_str_arg
...
@@ -14,7 +14,7 @@ class Country211(ImageFolder):
...
@@ -14,7 +14,7 @@ class Country211(ImageFolder):
100 test images for each country.
100 test images for each country.
Args:
Args:
root (str
ing
): Root directory of the dataset.
root (str
or ``pathlib.Path``
): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
version. E.g, ``transforms.RandomCrop``.
...
@@ -28,7 +28,7 @@ class Country211(ImageFolder):
...
@@ -28,7 +28,7 @@ class Country211(ImageFolder):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/dtd.py
View file @
2ba586d5
import
os
import
os
import
pathlib
import
pathlib
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
...
@@ -12,7 +12,7 @@ class DTD(VisionDataset):
...
@@ -12,7 +12,7 @@ class DTD(VisionDataset):
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
Args:
Args:
root (str
ing
): Root directory of the dataset.
root (str
or ``pathlib.Path``
): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
...
@@ -34,7 +34,7 @@ class DTD(VisionDataset):
...
@@ -34,7 +34,7 @@ class DTD(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
partition
:
int
=
1
,
partition
:
int
=
1
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/eurosat.py
View file @
2ba586d5
import
os
import
os
from
typing
import
Callable
,
Optional
from
pathlib
import
Path
from
typing
import
Callable
,
Optional
,
Union
from
.folder
import
ImageFolder
from
.folder
import
ImageFolder
from
.utils
import
download_and_extract_archive
from
.utils
import
download_and_extract_archive
...
@@ -9,7 +10,7 @@ class EuroSAT(ImageFolder):
...
@@ -9,7 +10,7 @@ class EuroSAT(ImageFolder):
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of dataset where ``root/eurosat`` exists.
root (str
or ``pathlib.Path``
): Root directory of dataset where ``root/eurosat`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target_transform (callable, optional): A function/transform that takes in the
...
@@ -21,7 +22,7 @@ class EuroSAT(ImageFolder):
...
@@ -21,7 +22,7 @@ class EuroSAT(ImageFolder):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
download
:
bool
=
False
,
...
...
torchvision/datasets/fer2013.py
View file @
2ba586d5
import
csv
import
csv
import
pathlib
import
pathlib
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -14,7 +14,7 @@ class FER2013(VisionDataset):
...
@@ -14,7 +14,7 @@ class FER2013(VisionDataset):
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of dataset where directory
root (str
or ``pathlib.Path``
): Root directory of dataset where directory
``root/fer2013`` exists.
``root/fer2013`` exists.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
...
@@ -29,7 +29,7 @@ class FER2013(VisionDataset):
...
@@ -29,7 +29,7 @@ class FER2013(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/fgvc_aircraft.py
View file @
2ba586d5
from
__future__
import
annotations
from
__future__
import
annotations
import
os
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
...
@@ -23,7 +24,7 @@ class FGVCAircraft(VisionDataset):
...
@@ -23,7 +24,7 @@ class FGVCAircraft(VisionDataset):
- ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
- ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
Args:
Args:
root (str
ing
): Root directory of the FGVC Aircraft dataset.
root (str
or ``pathlib.Path``
): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
``trainval`` and ``test``.
annotation_level (str, optional): The annotation level, supports ``variant``,
annotation_level (str, optional): The annotation level, supports ``variant``,
...
@@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset):
...
@@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"trainval"
,
split
:
str
=
"trainval"
,
annotation_level
:
str
=
"variant"
,
annotation_level
:
str
=
"variant"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/flickr.py
View file @
2ba586d5
...
@@ -2,7 +2,8 @@ import glob
...
@@ -2,7 +2,8 @@ import glob
import
os
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
html.parser
import
HTMLParser
from
html.parser
import
HTMLParser
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
from
PIL
import
Image
...
@@ -12,7 +13,7 @@ from .vision import VisionDataset
...
@@ -12,7 +13,7 @@ from .vision import VisionDataset
class
Flickr8kParser
(
HTMLParser
):
class
Flickr8kParser
(
HTMLParser
):
"""Parser for extracting captions from the Flickr8k dataset web page."""
"""Parser for extracting captions from the Flickr8k dataset web page."""
def
__init__
(
self
,
root
:
str
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
root
=
root
self
.
root
=
root
...
@@ -56,7 +57,7 @@ class Flickr8k(VisionDataset):
...
@@ -56,7 +57,7 @@ class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
"""`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory where images are downloaded to.
root (str
or ``pathlib.Path``
): Root directory where images are downloaded to.
ann_file (string): Path to annotation file.
ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
and returns a transformed version. E.g, ``transforms.PILToTensor``
...
@@ -66,7 +67,7 @@ class Flickr8k(VisionDataset):
...
@@ -66,7 +67,7 @@ class Flickr8k(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
ann_file
:
str
,
ann_file
:
str
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
@@ -112,7 +113,7 @@ class Flickr30k(VisionDataset):
...
@@ -112,7 +113,7 @@ class Flickr30k(VisionDataset):
"""`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
"""`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory where images are downloaded to.
root (str
or ``pathlib.Path``
): Root directory where images are downloaded to.
ann_file (string): Path to annotation file.
ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
and returns a transformed version. E.g, ``transforms.PILToTensor``
...
...
torchvision/datasets/flowers102.py
View file @
2ba586d5
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
...
@@ -22,7 +22,7 @@ class Flowers102(VisionDataset):
...
@@ -22,7 +22,7 @@ class Flowers102(VisionDataset):
have large variations within the category, and several very similar categories.
have large variations within the category, and several very similar categories.
Args:
Args:
root (str
ing
): Root directory of the dataset.
root (str
or ``pathlib.Path``
): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a
transform (callable, optional): A function/transform that takes in a PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``.
transformed version. E.g, ``transforms.RandomCrop``.
...
@@ -42,7 +42,7 @@ class Flowers102(VisionDataset):
...
@@ -42,7 +42,7 @@ class Flowers102(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/folder.py
View file @
2ba586d5
import
os
import
os
import
os.path
import
os.path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
PIL
import
Image
from
PIL
import
Image
...
@@ -32,7 +33,7 @@ def is_image_file(filename: str) -> bool:
...
@@ -32,7 +33,7 @@ def is_image_file(filename: str) -> bool:
return
has_file_allowed_extension
(
filename
,
IMG_EXTENSIONS
)
return
has_file_allowed_extension
(
filename
,
IMG_EXTENSIONS
)
def
find_classes
(
directory
:
str
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
def
find_classes
(
directory
:
Union
[
str
,
Path
]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
"""Finds the class folders in a dataset.
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
See :class:`DatasetFolder` for details.
...
@@ -46,7 +47,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
...
@@ -46,7 +47,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def
make_dataset
(
def
make_dataset
(
directory
:
str
,
directory
:
Union
[
str
,
Path
]
,
class_to_idx
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
class_to_idx
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
extensions
:
Optional
[
Union
[
str
,
Tuple
[
str
,
...]]]
=
None
,
extensions
:
Optional
[
Union
[
str
,
Tuple
[
str
,
...]]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
...
@@ -112,7 +113,7 @@ class DatasetFolder(VisionDataset):
...
@@ -112,7 +113,7 @@ class DatasetFolder(VisionDataset):
:meth:`find_classes` method.
:meth:`find_classes` method.
Args:
Args:
root (str
ing
): Root directory path.
root (str
or ``pathlib.Path``
): Root directory path.
loader (callable): A function to load a sample given its path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
both extensions and is_valid_file should not be passed.
...
@@ -136,7 +137,7 @@ class DatasetFolder(VisionDataset):
...
@@ -136,7 +137,7 @@ class DatasetFolder(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
loader
:
Callable
[[
str
],
Any
],
loader
:
Callable
[[
str
],
Any
],
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
...
@@ -164,7 +165,7 @@ class DatasetFolder(VisionDataset):
...
@@ -164,7 +165,7 @@ class DatasetFolder(VisionDataset):
@
staticmethod
@
staticmethod
def
make_dataset
(
def
make_dataset
(
directory
:
str
,
directory
:
Union
[
str
,
Path
]
,
class_to_idx
:
Dict
[
str
,
int
],
class_to_idx
:
Dict
[
str
,
int
],
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
extensions
:
Optional
[
Tuple
[
str
,
...]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
is_valid_file
:
Optional
[
Callable
[[
str
],
bool
]]
=
None
,
...
@@ -203,7 +204,7 @@ class DatasetFolder(VisionDataset):
...
@@ -203,7 +204,7 @@ class DatasetFolder(VisionDataset):
directory
,
class_to_idx
,
extensions
=
extensions
,
is_valid_file
=
is_valid_file
,
allow_empty
=
allow_empty
directory
,
class_to_idx
,
extensions
=
extensions
,
is_valid_file
=
is_valid_file
,
allow_empty
=
allow_empty
)
)
def
find_classes
(
self
,
directory
:
str
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
def
find_classes
(
self
,
directory
:
Union
[
str
,
Path
]
)
->
Tuple
[
List
[
str
],
Dict
[
str
,
int
]]:
"""Find the class folders in a dataset structured as follows::
"""Find the class folders in a dataset structured as follows::
directory/
directory/
...
@@ -298,7 +299,7 @@ class ImageFolder(DatasetFolder):
...
@@ -298,7 +299,7 @@ class ImageFolder(DatasetFolder):
the same methods can be overridden to customize the dataset.
the same methods can be overridden to customize the dataset.
Args:
Args:
root (str
ing
): Root directory path.
root (str
or ``pathlib.Path``
): Root directory path.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target_transform (callable, optional): A function/transform that takes in the
...
...
torchvision/datasets/food101.py
View file @
2ba586d5
import
json
import
json
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
...
@@ -19,7 +19,7 @@ class Food101(VisionDataset):
...
@@ -19,7 +19,7 @@ class Food101(VisionDataset):
Args:
Args:
root (str
ing
): Root directory of the dataset.
root (str
or ``pathlib.Path``
): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
version. E.g, ``transforms.RandomCrop``.
...
@@ -34,7 +34,7 @@ class Food101(VisionDataset):
...
@@ -34,7 +34,7 @@ class Food101(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/gtsrb.py
View file @
2ba586d5
import
csv
import
csv
import
pathlib
import
pathlib
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
import
PIL
import
PIL
...
@@ -13,7 +13,7 @@ class GTSRB(VisionDataset):
...
@@ -13,7 +13,7 @@ class GTSRB(VisionDataset):
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
Args:
Args:
root (str
ing
): Root directory of the dataset.
root (str
or ``pathlib.Path``
): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
version. E.g, ``transforms.RandomCrop``.
...
@@ -25,7 +25,7 @@ class GTSRB(VisionDataset):
...
@@ -25,7 +25,7 @@ class GTSRB(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
pathlib
.
Path
]
,
split
:
str
=
"train"
,
split
:
str
=
"train"
,
transform
:
Optional
[
Callable
]
=
None
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
...
...
torchvision/datasets/hmdb51.py
View file @
2ba586d5
import
glob
import
glob
import
os
import
os
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -28,7 +29,7 @@ class HMDB51(VisionDataset):
...
@@ -28,7 +29,7 @@ class HMDB51(VisionDataset):
Internally, it uses a VideoClips object to handle clip creation.
Internally, it uses a VideoClips object to handle clip creation.
Args:
Args:
root (str
ing
): Root directory of the HMDB51 Dataset.
root (str
or ``pathlib.Path``
): Root directory of the HMDB51 Dataset.
annotation_path (str): Path to the folder containing the split files.
annotation_path (str): Path to the folder containing the split files.
frames_per_clip (int): Number of frames in a clip.
frames_per_clip (int): Number of frames in a clip.
step_between_clips (int): Number of frames between each clip.
step_between_clips (int): Number of frames between each clip.
...
@@ -59,7 +60,7 @@ class HMDB51(VisionDataset):
...
@@ -59,7 +60,7 @@ class HMDB51(VisionDataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
annotation_path
:
str
,
annotation_path
:
str
,
frames_per_clip
:
int
,
frames_per_clip
:
int
,
step_between_clips
:
int
=
1
,
step_between_clips
:
int
=
1
,
...
...
torchvision/datasets/imagenet.py
View file @
2ba586d5
...
@@ -2,7 +2,8 @@ import os
...
@@ -2,7 +2,8 @@ import os
import
shutil
import
shutil
import
tempfile
import
tempfile
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -28,7 +29,7 @@ class ImageNet(ImageFolder):
...
@@ -28,7 +29,7 @@ class ImageNet(ImageFolder):
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
Args:
Args:
root (str
ing
): Root directory of the ImageNet Dataset.
root (str
or ``pathlib.Path``
): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
split (string, optional): The dataset split, supports ``train``, or ``val``.
transform (callable, optional): A function/transform that takes in a PIL image
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
and returns a transformed version. E.g, ``transforms.RandomCrop``
...
@@ -45,7 +46,7 @@ class ImageNet(ImageFolder):
...
@@ -45,7 +46,7 @@ class ImageNet(ImageFolder):
targets (list): The class_index value for each image in the dataset
targets (list): The class_index value for each image in the dataset
"""
"""
def
__init__
(
self
,
root
:
str
,
split
:
str
=
"train"
,
**
kwargs
:
Any
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
split
:
str
=
"train"
,
**
kwargs
:
Any
)
->
None
:
root
=
self
.
root
=
os
.
path
.
expanduser
(
root
)
root
=
self
.
root
=
os
.
path
.
expanduser
(
root
)
self
.
split
=
verify_str_arg
(
split
,
"split"
,
(
"train"
,
"val"
))
self
.
split
=
verify_str_arg
(
split
,
"split"
,
(
"train"
,
"val"
))
...
@@ -78,7 +79,7 @@ class ImageNet(ImageFolder):
...
@@ -78,7 +79,7 @@ class ImageNet(ImageFolder):
return
"Split: {split}"
.
format
(
**
self
.
__dict__
)
return
"Split: {split}"
.
format
(
**
self
.
__dict__
)
def
load_meta_file
(
root
:
str
,
file
:
Optional
[
str
]
=
None
)
->
Tuple
[
Dict
[
str
,
str
],
List
[
str
]]:
def
load_meta_file
(
root
:
Union
[
str
,
Path
]
,
file
:
Optional
[
str
]
=
None
)
->
Tuple
[
Dict
[
str
,
str
],
List
[
str
]]:
if
file
is
None
:
if
file
is
None
:
file
=
META_FILE
file
=
META_FILE
file
=
os
.
path
.
join
(
root
,
file
)
file
=
os
.
path
.
join
(
root
,
file
)
...
@@ -93,7 +94,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
...
@@ -93,7 +94,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
raise
RuntimeError
(
msg
.
format
(
file
,
root
))
raise
RuntimeError
(
msg
.
format
(
file
,
root
))
def
_verify_archive
(
root
:
str
,
file
:
str
,
md5
:
str
)
->
None
:
def
_verify_archive
(
root
:
Union
[
str
,
Path
]
,
file
:
str
,
md5
:
str
)
->
None
:
if
not
check_integrity
(
os
.
path
.
join
(
root
,
file
),
md5
):
if
not
check_integrity
(
os
.
path
.
join
(
root
,
file
),
md5
):
msg
=
(
msg
=
(
"The archive {} is not present in the root directory or is corrupted. "
"The archive {} is not present in the root directory or is corrupted. "
...
@@ -102,12 +103,12 @@ def _verify_archive(root: str, file: str, md5: str) -> None:
...
@@ -102,12 +103,12 @@ def _verify_archive(root: str, file: str, md5: str) -> None:
raise
RuntimeError
(
msg
.
format
(
file
,
root
))
raise
RuntimeError
(
msg
.
format
(
file
,
root
))
def
parse_devkit_archive
(
root
:
str
,
file
:
Optional
[
str
]
=
None
)
->
None
:
def
parse_devkit_archive
(
root
:
Union
[
str
,
Path
]
,
file
:
Optional
[
str
]
=
None
)
->
None
:
"""Parse the devkit archive of the ImageNet2012 classification dataset and save
"""Parse the devkit archive of the ImageNet2012 classification dataset and save
the meta information in a binary file.
the meta information in a binary file.
Args:
Args:
root (str): Root directory containing the devkit archive
root (str
or ``pathlib.Path``
): Root directory containing the devkit archive
file (str, optional): Name of devkit archive. Defaults to
file (str, optional): Name of devkit archive. Defaults to
'ILSVRC2012_devkit_t12.tar.gz'
'ILSVRC2012_devkit_t12.tar.gz'
"""
"""
...
@@ -156,12 +157,12 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
...
@@ -156,12 +157,12 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
torch
.
save
((
wnid_to_classes
,
val_wnids
),
os
.
path
.
join
(
root
,
META_FILE
))
torch
.
save
((
wnid_to_classes
,
val_wnids
),
os
.
path
.
join
(
root
,
META_FILE
))
def
parse_train_archive
(
root
:
str
,
file
:
Optional
[
str
]
=
None
,
folder
:
str
=
"train"
)
->
None
:
def
parse_train_archive
(
root
:
Union
[
str
,
Path
]
,
file
:
Optional
[
str
]
=
None
,
folder
:
str
=
"train"
)
->
None
:
"""Parse the train images archive of the ImageNet2012 classification dataset and
"""Parse the train images archive of the ImageNet2012 classification dataset and
prepare it for usage with the ImageNet dataset.
prepare it for usage with the ImageNet dataset.
Args:
Args:
root (str): Root directory containing the train images archive
root (str
or ``pathlib.Path``
): Root directory containing the train images archive
file (str, optional): Name of train images archive. Defaults to
file (str, optional): Name of train images archive. Defaults to
'ILSVRC2012_img_train.tar'
'ILSVRC2012_img_train.tar'
folder (str, optional): Optional name for train images folder. Defaults to
folder (str, optional): Optional name for train images folder. Defaults to
...
@@ -183,13 +184,13 @@ def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "tr
...
@@ -183,13 +184,13 @@ def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "tr
def
parse_val_archive
(
def
parse_val_archive
(
root
:
str
,
file
:
Optional
[
str
]
=
None
,
wnids
:
Optional
[
List
[
str
]]
=
None
,
folder
:
str
=
"val"
root
:
Union
[
str
,
Path
]
,
file
:
Optional
[
str
]
=
None
,
wnids
:
Optional
[
List
[
str
]]
=
None
,
folder
:
str
=
"val"
)
->
None
:
)
->
None
:
"""Parse the validation images archive of the ImageNet2012 classification dataset
"""Parse the validation images archive of the ImageNet2012 classification dataset
and prepare it for usage with the ImageNet dataset.
and prepare it for usage with the ImageNet dataset.
Args:
Args:
root (str): Root directory containing the validation images archive
root (str
or ``pathlib.Path``
): Root directory containing the validation images archive
file (str, optional): Name of validation images archive. Defaults to
file (str, optional): Name of validation images archive. Defaults to
'ILSVRC2012_img_val.tar'
'ILSVRC2012_img_val.tar'
wnids (list, optional): List of WordNet IDs of the validation images. If None
wnids (list, optional): List of WordNet IDs of the validation images. If None
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment