Unverified Commit 2ba586d5 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Document that datasets support pathlib.Path (#8321)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 03251754
......@@ -13,7 +13,6 @@ from ..io.image import _read_png_16
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
......@@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root=root)
self.transforms = transforms
......@@ -113,7 +112,7 @@ class Sintel(FlowDataset):
...
Args:
root (string): Root directory of the Sintel Dataset.
root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
......@@ -125,7 +124,7 @@ class Sintel(FlowDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
......@@ -183,7 +182,7 @@ class KittiFlow(FlowDataset):
flow_occ
Args:
root (string): Root directory of the KittiFlow Dataset.
root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
......@@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -248,7 +247,7 @@ class FlyingChairs(FlowDataset):
Args:
root (string): Root directory of the FlyingChairs Dataset.
root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
......@@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "val"))
......@@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
TRAIN
Args:
root (string): Root directory of the intel FlyingThings3D Dataset.
root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
......@@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
......@@ -411,7 +410,7 @@ class HD1K(FlowDataset):
image_2
Args:
root (string): Root directory of the HD1K Dataset.
root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
......@@ -419,7 +418,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......
......@@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
......@@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where `carla-highres` is located.
root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "carla-highres"
......@@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
calib
Args:
root (string): Root directory where `Kitti2012` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
calib
Args:
root (string): Root directory where `Kitti2015` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...
Args:
root (string): Root directory of the Middleburry 2014 Dataset.
root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
......@@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
calibration: Optional[str] = "perfect",
use_ambient_views: bool = False,
......@@ -576,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask
def _download_dataset(self, root: str) -> None:
def _download_dataset(self, root: Union[str, Path]) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
......@@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms)
......@@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where FallingThings is located.
root (str or ``pathlib.Path``): Root directory where FallingThings is located.
variant (string): Which variant to use. Either "single", "mixed", or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "FallingThings"
......@@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where SceneFlow is located.
root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
......@@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
......@@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory where Sintel Stereo is located.
root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
_has_built_in_disparity_mask = True
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
......@@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
...
Args:
root (string): Root directory where InStereo2k is located.
root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
split (string): Either "train" or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "InStereo2k" / split
......@@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
...
Args:
root (string): Root directory of the ETH3D Dataset.
root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......
import os
import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
......@@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified
......@@ -38,7 +39,7 @@ class Caltech101(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......@@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
......
import csv
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
import PIL
......@@ -16,7 +17,7 @@ class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
......@@ -63,7 +64,7 @@ class CelebA(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None,
......
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import numpy as np
from PIL import Image
......@@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
......@@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import json
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image
......@@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val``
......@@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
......
import json
import pathlib
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse
from PIL import Image
......@@ -15,7 +15,7 @@ class CLEVRClassification(VisionDataset):
The number of objects in a scene are used as label.
Args:
root (string): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
root (str or ``pathlib.Path``): Root directory of dataset where directory ``root/clevr`` exists or will be saved to if download is
set to True.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
......@@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import os.path
from typing import Any, Callable, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
......@@ -12,7 +13,7 @@ class CocoDetection(VisionDataset):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
......@@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
annFile: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......@@ -67,7 +68,7 @@ class CocoCaptions(CocoDetection):
It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
......
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Optional, Union
from .folder import ImageFolder
from .utils import download_and_extract_archive, verify_str_arg
......@@ -14,7 +14,7 @@ class Country211(ImageFolder):
100 test images for each country.
Args:
root (string): Root directory of the dataset.
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
......@@ -28,7 +28,7 @@ class Country211(ImageFolder):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import os
import pathlib
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
......@@ -12,7 +12,7 @@ class DTD(VisionDataset):
"""`Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`_.
Args:
root (string): Root directory of the dataset.
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
......@@ -34,7 +34,7 @@ class DTD(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
partition: int = 1,
transform: Optional[Callable] = None,
......
import os
from typing import Callable, Optional
from pathlib import Path
from typing import Callable, Optional, Union
from .folder import ImageFolder
from .utils import download_and_extract_archive
......@@ -9,7 +10,7 @@ class EuroSAT(ImageFolder):
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
Args:
root (string): Root directory of dataset where ``root/eurosat`` exists.
root (str or ``pathlib.Path``): Root directory of dataset where ``root/eurosat`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
......@@ -21,7 +22,7 @@ class EuroSAT(ImageFolder):
def __init__(
self,
root: str,
root: Union[str, Path],
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
......
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import torch
from PIL import Image
......@@ -14,7 +14,7 @@ class FER2013(VisionDataset):
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``root/fer2013`` exists.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
......@@ -29,7 +29,7 @@ class FER2013(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
from __future__ import annotations
import os
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
......@@ -23,7 +24,7 @@ class FGVCAircraft(VisionDataset):
- ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
Args:
root (string): Root directory of the FGVC Aircraft dataset.
root (str or ``pathlib.Path``): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
annotation_level (str, optional): The annotation level, supports ``variant``,
......@@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "trainval",
annotation_level: str = "variant",
transform: Optional[Callable] = None,
......
......@@ -2,7 +2,8 @@ import glob
import os
from collections import defaultdict
from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image
......@@ -12,7 +13,7 @@ from .vision import VisionDataset
class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root: str) -> None:
def __init__(self, root: Union[str, Path]) -> None:
super().__init__()
self.root = root
......@@ -56,7 +57,7 @@ class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
......@@ -66,7 +67,7 @@ class Flickr8k(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......@@ -112,7 +113,7 @@ class Flickr30k(VisionDataset):
"""`Flickr30k Entities <https://bryanplummer.com/Flickr30kEntities/>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
ann_file (string): Path to annotation file.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.PILToTensor``
......
from pathlib import Path
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
......@@ -22,7 +22,7 @@ class Flowers102(VisionDataset):
have large variations within the category, and several very similar categories.
Args:
root (string): Root directory of the dataset.
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``.
......@@ -42,7 +42,7 @@ class Flowers102(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import os
import os.path
from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from PIL import Image
......@@ -32,7 +33,7 @@ def is_image_file(filename: str) -> bool:
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
......@@ -46,7 +47,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
def make_dataset(
directory: str,
directory: Union[str, Path],
class_to_idx: Optional[Dict[str, int]] = None,
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
......@@ -112,7 +113,7 @@ class DatasetFolder(VisionDataset):
:meth:`find_classes` method.
Args:
root (string): Root directory path.
root (str or ``pathlib.Path``): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
......@@ -136,7 +137,7 @@ class DatasetFolder(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
......@@ -164,7 +165,7 @@ class DatasetFolder(VisionDataset):
@staticmethod
def make_dataset(
directory: str,
directory: Union[str, Path],
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
......@@ -203,7 +204,7 @@ class DatasetFolder(VisionDataset):
directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty
)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
def find_classes(self, directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows::
directory/
......@@ -298,7 +299,7 @@ class ImageFolder(DatasetFolder):
the same methods can be overridden to customize the dataset.
Args:
root (string): Root directory path.
root (str or ``pathlib.Path``): Root directory path.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
......
import json
from pathlib import Path
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import PIL.Image
......@@ -19,7 +19,7 @@ class Food101(VisionDataset):
Args:
root (string): Root directory of the dataset.
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
......@@ -34,7 +34,7 @@ class Food101(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union
import PIL
......@@ -13,7 +13,7 @@ class GTSRB(VisionDataset):
"""`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset.
Args:
root (string): Root directory of the dataset.
root (str or ``pathlib.Path``): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
......@@ -25,7 +25,7 @@ class GTSRB(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
......
import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch import Tensor
......@@ -28,7 +29,7 @@ class HMDB51(VisionDataset):
Internally, it uses a VideoClips object to handle clip creation.
Args:
root (string): Root directory of the HMDB51 Dataset.
root (str or ``pathlib.Path``): Root directory of the HMDB51 Dataset.
annotation_path (str): Path to the folder containing the split files.
frames_per_clip (int): Number of frames in a clip.
step_between_clips (int): Number of frames between each clip.
......@@ -59,7 +60,7 @@ class HMDB51(VisionDataset):
def __init__(
self,
root: str,
root: Union[str, Path],
annotation_path: str,
frames_per_clip: int,
step_between_clips: int = 1,
......
......@@ -2,7 +2,8 @@ import os
import shutil
import tempfile
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
......@@ -28,7 +29,7 @@ class ImageNet(ImageFolder):
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
Args:
root (string): Root directory of the ImageNet Dataset.
root (str or ``pathlib.Path``): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
......@@ -45,7 +46,7 @@ class ImageNet(ImageFolder):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None:
root = self.root = os.path.expanduser(root)
self.split = verify_str_arg(split, "split", ("train", "val"))
......@@ -78,7 +79,7 @@ class ImageNet(ImageFolder):
return "Split: {split}".format(**self.__dict__)
def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
def load_meta_file(root: Union[str, Path], file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
if file is None:
file = META_FILE
file = os.path.join(root, file)
......@@ -93,7 +94,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
raise RuntimeError(msg.format(file, root))
def _verify_archive(root: str, file: str, md5: str) -> None:
def _verify_archive(root: Union[str, Path], file: str, md5: str) -> None:
if not check_integrity(os.path.join(root, file), md5):
msg = (
"The archive {} is not present in the root directory or is corrupted. "
......@@ -102,12 +103,12 @@ def _verify_archive(root: str, file: str, md5: str) -> None:
raise RuntimeError(msg.format(file, root))
def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def parse_devkit_archive(root: Union[str, Path], file: Optional[str] = None) -> None:
"""Parse the devkit archive of the ImageNet2012 classification dataset and save
the meta information in a binary file.
Args:
root (str): Root directory containing the devkit archive
root (str or ``pathlib.Path``): Root directory containing the devkit archive
file (str, optional): Name of devkit archive. Defaults to
'ILSVRC2012_devkit_t12.tar.gz'
"""
......@@ -156,12 +157,12 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
def parse_train_archive(root: Union[str, Path], file: Optional[str] = None, folder: str = "train") -> None:
"""Parse the train images archive of the ImageNet2012 classification dataset and
prepare it for usage with the ImageNet dataset.
Args:
root (str): Root directory containing the train images archive
root (str or ``pathlib.Path``): Root directory containing the train images archive
file (str, optional): Name of train images archive. Defaults to
'ILSVRC2012_img_train.tar'
folder (str, optional): Optional name for train images folder. Defaults to
......@@ -183,13 +184,13 @@ def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "tr
def parse_val_archive(
root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
root: Union[str, Path], file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
) -> None:
"""Parse the validation images archive of the ImageNet2012 classification dataset
and prepare it for usage with the ImageNet dataset.
Args:
root (str): Root directory containing the validation images archive
root (str or ``pathlib.Path``): Root directory containing the validation images archive
file (str, optional): Name of validation images archive. Defaults to
'ILSVRC2012_img_val.tar'
wnids (list, optional): List of WordNet IDs of the validation images. If None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment