Unverified commit 2ba586d5, authored by Nicolas Hug, committed by GitHub
Browse files

Document that datasets support pathlib.Path (#8321)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 03251754
...@@ -13,7 +13,6 @@ from ..io.image import _read_png_16 ...@@ -13,7 +13,6 @@ from ..io.image import _read_png_16
from .utils import _read_pfm, verify_str_arg from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]] T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]] T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
...@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset): ...@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False _has_builtin_flow_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root=root) super().__init__(root=root)
self.transforms = transforms self.transforms = transforms
...@@ -113,7 +112,7 @@ class Sintel(FlowDataset): ...@@ -113,7 +112,7 @@ class Sintel(FlowDataset):
... ...
Args: Args:
root (string): Root directory of the Sintel Dataset. root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test" split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes. details on the different passes.
...@@ -125,7 +124,7 @@ class Sintel(FlowDataset): ...@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
pass_name: str = "clean", pass_name: str = "clean",
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
...@@ -183,7 +182,7 @@ class KittiFlow(FlowDataset): ...@@ -183,7 +182,7 @@ class KittiFlow(FlowDataset):
flow_occ flow_occ
Args: Args:
root (string): Root directory of the KittiFlow Dataset. root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test" split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset): ...@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -248,7 +247,7 @@ class FlyingChairs(FlowDataset): ...@@ -248,7 +247,7 @@ class FlyingChairs(FlowDataset):
Args: Args:
root (string): Root directory of the FlyingChairs Dataset. root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val" split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset): ...@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "val")) verify_str_arg(split, "split", valid_values=("train", "val"))
...@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset): ...@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
TRAIN TRAIN
Args: Args:
root (string): Root directory of the intel FlyingThings3D Dataset. root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test" split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes. details on the different passes.
...@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset): ...@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
pass_name: str = "clean", pass_name: str = "clean",
camera: str = "left", camera: str = "left",
...@@ -411,7 +410,7 @@ class HD1K(FlowDataset): ...@@ -411,7 +410,7 @@ class HD1K(FlowDataset):
image_2 image_2
Args: Args:
root (string): Root directory of the HD1K Dataset. root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test" split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
...@@ -419,7 +418,7 @@ class HD1K(FlowDataset): ...@@ -419,7 +418,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
......
...@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask = False _has_built_in_disparity_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
""" """
Args: Args:
root(str): Root directory of the dataset. root(str): Root directory of the dataset.
...@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory where `carla-highres` is located. root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "carla-highres" root = Path(root) / "carla-highres"
...@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
calib calib
Args: Args:
root (string): Root directory where `Kitti2012` is located. root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test". split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
calib calib
Args: Args:
root (string): Root directory where `Kitti2015` is located. root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test". split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory of the Middlebury 2014 Dataset. root (str or ``pathlib.Path``): Root directory of the Middlebury 2014 Dataset.
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional" split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible. use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``. The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
...@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
calibration: Optional[str] = "perfect", calibration: Optional[str] = "perfect",
use_ambient_views: bool = False, use_ambient_views: bool = False,
...@@ -576,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -576,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask return disparity_map, valid_mask
def _download_dataset(self, root: str) -> None: def _download_dataset(self, root: Union[str, Path]) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings # train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014" root = Path(root) / "Middlebury2014"
...@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset): ...@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
) -> None: ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
...@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory where FallingThings is located. root (str or ``pathlib.Path``): Root directory where FallingThings is located.
variant (string): Which variant to use. Either "single", "mixed", or "both". variant (string): Which variant to use. Either "single", "mixed", or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "FallingThings" root = Path(root) / "FallingThings"
...@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory where SceneFlow is located. root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving". variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
pass_name (string): Which pass to use, "clean" (default), "final" or "both". pass_name (string): Which pass to use, "clean" (default), "final" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
...@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
variant: str = "FlyingThings3D", variant: str = "FlyingThings3D",
pass_name: str = "clean", pass_name: str = "clean",
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
...@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset): ...@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory where Sintel Stereo is located. root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
pass_name (string): The name of the pass to use, either "final", "clean" or "both". pass_name (string): The name of the pass to use, either "final", "clean" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
...@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory where InStereo2k is located. root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
split (string): Either "train" or "test". split (string): Either "train" or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "InStereo2k" / split root = Path(root) / "InStereo2k" / split
...@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
... ...
Args: Args:
root (string): Root directory of the ETH3D Dataset. root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
split (string, optional): The dataset split of scenes, either "train" (default) or "test". split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
......
import os import os
import os.path import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -16,7 +17,7 @@ class Caltech101(VisionDataset): ...@@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format. This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
Args: Args:
root (string): Root directory of dataset where directory root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True. ``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified ``annotation``. Can also be a list to output a tuple with all specified
...@@ -38,7 +39,7 @@ class Caltech101(VisionDataset): ...@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
target_type: Union[List[str], str] = "category", target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
...@@ -153,7 +154,7 @@ class Caltech256(VisionDataset): ...@@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset. """`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True. ``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
......
import csv import csv
import os import os
from collections import namedtuple from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
import PIL import PIL
...@@ -16,7 +17,7 @@ class CelebA(VisionDataset): ...@@ -16,7 +17,7 @@ class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset. """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args: Args:
root (string): Root directory where images are downloaded to. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test', 'all'}. split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected. Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
...@@ -63,7 +64,7 @@ class CelebA(VisionDataset): ...@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
target_type: Union[List[str], str] = "attr", target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
......
import os.path import os.path
import pickle import pickle
from typing import Any, Callable, Optional, Tuple from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset): ...@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (str or ``pathlib.Path``): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True. ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set. creates from test set.
...@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset): ...@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
train: bool = True, train: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import json import json
import os import os
from collections import namedtuple from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset): ...@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory ``leftImg8bit`` root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located. and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine" split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val`` otherwise ``train``, ``train_extra`` or ``val``
...@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset): ...@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
mode: str = "fine", mode: str = "fine",
target_type: Union[List[str], str] = "instance", target_type: Union[List[str], str] = "instance",
......
import json import json
import pathlib import pathlib
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from PIL import Image from PIL import Image
...@@ -15,7 +15,7 @@ class CLEVRClassification(VisionDataset): ...@@ -15,7 +15,7 @@ class CLEVRClassification(VisionDataset):
The number of objects in a scene are used as label. The number of objects in a scene are used as label.
Args: Args:
root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True. set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
...@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset): ...@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, pathlib.Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import os.path import os.path
from typing import Any, Callable, List, Optional, Tuple from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -12,7 +13,7 @@ class CocoDetection(VisionDataset): ...@@ -12,7 +13,7 @@ class CocoDetection(VisionDataset):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_. It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args: Args:
root (string): Root directory where images are downloaded to. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
annFile (string): Path to json annotation file. annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
...@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset): ...@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
annFile: str, annFile: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
...@@ -67,7 +68,7 @@ class CocoCaptions(CocoDetection): ...@@ -67,7 +68,7 @@ class CocoCaptions(CocoDetection):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_. It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args: Args:
root (string): Root directory where images are downloaded to. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
annFile (string): Path to json annotation file. annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
......
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional, Union
from .folder import ImageFolder from .folder import ImageFolder
from .utils import download_and_extract_archive, verify_str_arg from .utils import download_and_extract_archive, verify_str_arg
...@@ -14,7 +14,7 @@ class Country211(ImageFolder): ...@@ -14,7 +14,7 @@ class Country211(ImageFolder):
100 test images for each country. 100 test images for each country.
Args: Args:
root (string): Root directory of the dataset. root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
...@@ -28,7 +28,7 @@ class Country211(ImageFolder): ...@@ -28,7 +28,7 @@ class Country211(ImageFolder):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import os import os
import pathlib import pathlib
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image import PIL.Image
...@@ -12,7 +12,7 @@ class DTD(VisionDataset): ...@@ -12,7 +12,7 @@ class DTD(VisionDataset):
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_. """`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
Args: Args:
root (string): Root directory of the dataset. root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``. partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
...@@ -34,7 +34,7 @@ class DTD(VisionDataset): ...@@ -34,7 +34,7 @@ class DTD(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, pathlib.Path],
split: str = "train", split: str = "train",
partition: int = 1, partition: int = 1,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
......
import os import os
from typing import Callable, Optional from pathlib import Path
from typing import Callable, Optional, Union
from .folder import ImageFolder from .folder import ImageFolder
from .utils import download_and_extract_archive from .utils import download_and_extract_archive
...@@ -9,7 +10,7 @@ class EuroSAT(ImageFolder): ...@@ -9,7 +10,7 @@ class EuroSAT(ImageFolder):
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset. """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where ``root/eurosat`` exists. root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
...@@ -21,7 +22,7 @@ class EuroSAT(ImageFolder): ...@@ -21,7 +22,7 @@ class EuroSAT(ImageFolder):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
......
import csv import csv
import pathlib import pathlib
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, Union
import torch import torch
from PIL import Image from PIL import Image
...@@ -14,7 +14,7 @@ class FER2013(VisionDataset): ...@@ -14,7 +14,7 @@ class FER2013(VisionDataset):
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset. <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (str or ``pathlib.Path``): Root directory of dataset where directory
``root/fer2013`` exists. ``root/fer2013`` exists.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
...@@ -29,7 +29,7 @@ class FER2013(VisionDataset): ...@@ -29,7 +29,7 @@ class FER2013(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, pathlib.Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any, Callable, Optional, Tuple from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image import PIL.Image
...@@ -23,7 +24,7 @@ class FGVCAircraft(VisionDataset): ...@@ -23,7 +24,7 @@ class FGVCAircraft(VisionDataset):
- ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers. - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
Args: Args:
root (string): Root directory of the FGVC Aircraft dataset. root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``, split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``. ``trainval`` and ``test``.
annotation_level (str, optional): The annotation level, supports ``variant``, annotation_level (str, optional): The annotation level, supports ``variant``,
...@@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset): ...@@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "trainval", split: str = "trainval",
annotation_level: str = "variant", annotation_level: str = "variant",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
......
...@@ -2,7 +2,8 @@ import glob ...@@ -2,7 +2,8 @@ import glob
import os import os
from collections import defaultdict from collections import defaultdict
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -12,7 +13,7 @@ from .vision import VisionDataset ...@@ -12,7 +13,7 @@ from .vision import VisionDataset
class Flickr8kParser(HTMLParser): class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page.""" """Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root: str) -> None: def __init__(self, root: Union[str, Path]) -> None:
super().__init__() super().__init__()
self.root = root self.root = root
...@@ -56,7 +57,7 @@ class Flickr8k(VisionDataset): ...@@ -56,7 +57,7 @@ class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset. """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
Args: Args:
root (string): Root directory where images are downloaded to. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
ann_file (string): Path to annotation file. ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
...@@ -66,7 +67,7 @@ class Flickr8k(VisionDataset): ...@@ -66,7 +67,7 @@ class Flickr8k(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
ann_file: str, ann_file: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
...@@ -112,7 +113,7 @@ class Flickr30k(VisionDataset): ...@@ -112,7 +113,7 @@ class Flickr30k(VisionDataset):
"""`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset. """`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
Args: Args:
root (string): Root directory where images are downloaded to. root (str or ``pathlib.Path``): Root directory where images are downloaded to.
ann_file (string): Path to annotation file. ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor`` and returns a transformed version. E.g, ``transforms.PILToTensor``
......
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image import PIL.Image
...@@ -22,7 +22,7 @@ class Flowers102(VisionDataset): ...@@ -22,7 +22,7 @@ class Flowers102(VisionDataset):
have large variations within the category, and several very similar categories. have large variations within the category, and several very similar categories.
Args: Args:
root (string): Root directory of the dataset. root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transform (callable, optional): A function/transform that takes in a PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``. transformed version. E.g, ``transforms.RandomCrop``.
...@@ -42,7 +42,7 @@ class Flowers102(VisionDataset): ...@@ -42,7 +42,7 @@ class Flowers102(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import os import os
import os.path import os.path
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -32,7 +33,7 @@ def is_image_file(filename: str) -> bool: ...@@ -32,7 +33,7 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS) return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset. """Finds the class folders in a dataset.
See :class:`DatasetFolder` for details. See :class:`DatasetFolder` for details.
...@@ -46,7 +47,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: ...@@ -46,7 +47,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def make_dataset( def make_dataset(
directory: str, directory: Union[str, Path],
class_to_idx: Optional[Dict[str, int]] = None, class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
...@@ -112,7 +113,7 @@ class DatasetFolder(VisionDataset): ...@@ -112,7 +113,7 @@ class DatasetFolder(VisionDataset):
:meth:`find_classes` method. :meth:`find_classes` method.
Args: Args:
root (string): Root directory path. root (str or ``pathlib.Path``): Root directory path.
loader (callable): A function to load a sample given its path. loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions. extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed. both extensions and is_valid_file should not be passed.
...@@ -136,7 +137,7 @@ class DatasetFolder(VisionDataset): ...@@ -136,7 +137,7 @@ class DatasetFolder(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
loader: Callable[[str], Any], loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
...@@ -164,7 +165,7 @@ class DatasetFolder(VisionDataset): ...@@ -164,7 +165,7 @@ class DatasetFolder(VisionDataset):
@staticmethod @staticmethod
def make_dataset( def make_dataset(
directory: str, directory: Union[str, Path],
class_to_idx: Dict[str, int], class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
...@@ -203,7 +204,7 @@ class DatasetFolder(VisionDataset): ...@@ -203,7 +204,7 @@ class DatasetFolder(VisionDataset):
directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
) )
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows:: """Find the class folders in a dataset structured as follows::
directory/ directory/
...@@ -298,7 +299,7 @@ class ImageFolder(DatasetFolder): ...@@ -298,7 +299,7 @@ class ImageFolder(DatasetFolder):
the same methods can be overridden to customize the dataset. the same methods can be overridden to customize the dataset.
Args: Args:
root (string): Root directory path. root (str or ``pathlib.Path``): Root directory path.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
......
import json import json
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image import PIL.Image
...@@ -19,7 +19,7 @@ class Food101(VisionDataset): ...@@ -19,7 +19,7 @@ class Food101(VisionDataset):
Args: Args:
root (string): Root directory of the dataset. root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
...@@ -34,7 +34,7 @@ class Food101(VisionDataset): ...@@ -34,7 +34,7 @@ class Food101(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import csv import csv
import pathlib import pathlib
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple, Union
import PIL import PIL
...@@ -13,7 +13,7 @@ class GTSRB(VisionDataset): ...@@ -13,7 +13,7 @@ class GTSRB(VisionDataset):
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset. """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
Args: Args:
root (string): Root directory of the dataset. root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
...@@ -25,7 +25,7 @@ class GTSRB(VisionDataset): ...@@ -25,7 +25,7 @@ class GTSRB(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, pathlib.Path],
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
......
import glob import glob
import os import os
from typing import Any, Callable, Dict, List, Optional, Tuple from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch import Tensor from torch import Tensor
...@@ -28,7 +29,7 @@ class HMDB51(VisionDataset): ...@@ -28,7 +29,7 @@ class HMDB51(VisionDataset):
Internally, it uses a VideoClips object to handle clip creation. Internally, it uses a VideoClips object to handle clip creation.
Args: Args:
root (string): Root directory of the HMDB51 Dataset. root (str or ``pathlib.Path``): Root directory of the HMDB51 Dataset.
annotation_path (str): Path to the folder containing the split files. annotation_path (str): Path to the folder containing the split files.
frames_per_clip (int): Number of frames in a clip. frames_per_clip (int): Number of frames in a clip.
step_between_clips (int): Number of frames between each clip. step_between_clips (int): Number of frames between each clip.
...@@ -59,7 +60,7 @@ class HMDB51(VisionDataset): ...@@ -59,7 +60,7 @@ class HMDB51(VisionDataset):
def __init__( def __init__(
self, self,
root: str, root: Union[str, Path],
annotation_path: str, annotation_path: str,
frames_per_clip: int, frames_per_clip: int,
step_between_clips: int = 1, step_between_clips: int = 1,
......
...@@ -2,7 +2,8 @@ import os ...@@ -2,7 +2,8 @@ import os
import shutil import shutil
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch import torch
...@@ -28,7 +29,7 @@ class ImageNet(ImageFolder): ...@@ -28,7 +29,7 @@ class ImageNet(ImageFolder):
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory. or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
Args: Args:
root (string): Root directory of the ImageNet Dataset. root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``. split (string, optional): The dataset split, supports ``train``, or ``val``.
transform (callable, optional): A function/transform that takes in a PIL image transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
...@@ -45,7 +46,7 @@ class ImageNet(ImageFolder): ...@@ -45,7 +46,7 @@ class ImageNet(ImageFolder):
targets (list): The class_index value for each image in the dataset targets (list): The class_index value for each image in the dataset
""" """
def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None: def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
root = self.root = os.path.expanduser(root) root = self.root = os.path.expanduser(root)
self.split = verify_str_arg(split, "split", ("train", "val")) self.split = verify_str_arg(split, "split", ("train", "val"))
...@@ -78,7 +79,7 @@ class ImageNet(ImageFolder): ...@@ -78,7 +79,7 @@ class ImageNet(ImageFolder):
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]: def load_meta_file(root: Union[str, Path], file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
if file is None: if file is None:
file = META_FILE file = META_FILE
file = os.path.join(root, file) file = os.path.join(root, file)
...@@ -93,7 +94,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str ...@@ -93,7 +94,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
raise RuntimeError(msg.format(file, root)) raise RuntimeError(msg.format(file, root))
def _verify_archive(root: str, file: str, md5: str) -> None: def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
if not check_integrity(os.path.join(root, file), md5): if not check_integrity(os.path.join(root, file), md5):
msg = ( msg = (
"The archive {} is not present in the root directory or is corrupted. " "The archive {} is not present in the root directory or is corrupted. "
...@@ -102,12 +103,12 @@ def _verify_archive(root: str, file: str, md5: str) -> None: ...@@ -102,12 +103,12 @@ def _verify_archive(root: str, file: str, md5: str) -> None:
raise RuntimeError(msg.format(file, root)) raise RuntimeError(msg.format(file, root))
def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
"""Parse the devkit archive of the ImageNet2012 classification dataset and save """Parse the devkit archive of the ImageNet2012 classification dataset and save
the meta information in a binary file. the meta information in a binary file.
Args: Args:
root (str): Root directory containing the devkit archive root (str or ``pathlib.Path``): Root directory containing the devkit archive
file (str, optional): Name of devkit archive. Defaults to file (str, optional): Name of devkit archive. Defaults to
'ILSVRC2012_devkit_t12.tar.gz' 'ILSVRC2012_devkit_t12.tar.gz'
""" """
...@@ -156,12 +157,12 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: ...@@ -156,12 +157,12 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None: def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
"""Parse the train images archive of the ImageNet2012 classification dataset and """Parse the train images archive of the ImageNet2012 classification dataset and
prepare it for usage with the ImageNet dataset. prepare it for usage with the ImageNet dataset.
Args: Args:
root (str): Root directory containing the train images archive root (str or ``pathlib.Path``): Root directory containing the train images archive
file (str, optional): Name of train images archive. Defaults to file (str, optional): Name of train images archive. Defaults to
'ILSVRC2012_img_train.tar' 'ILSVRC2012_img_train.tar'
folder (str, optional): Optional name for train images folder. Defaults to folder (str, optional): Optional name for train images folder. Defaults to
...@@ -183,13 +184,13 @@ def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "tr ...@@ -183,13 +184,13 @@ def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "tr
def parse_val_archive( def parse_val_archive(
root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val" root: Union[str, Path], file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None: ) -> None:
"""Parse the validation images archive of the ImageNet2012 classification dataset """Parse the validation images archive of the ImageNet2012 classification dataset
and prepare it for usage with the ImageNet dataset. and prepare it for usage with the ImageNet dataset.
Args: Args:
root (str): Root directory containing the validation images archive root (str or ``pathlib.Path``): Root directory containing the validation images archive
file (str, optional): Name of validation images archive. Defaults to file (str, optional): Name of validation images archive. Defaults to
'ILSVRC2012_img_val.tar' 'ILSVRC2012_img_val.tar'
wnids (list, optional): List of WordNet IDs of the validation images. If None wnids (list, optional): List of WordNet IDs of the validation images. If None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment