Unverified Commit 54d88f79 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Expose public methods in docs for datasets (#3732)

parent 5ac27fe3
...@@ -13,17 +13,11 @@ For example: :: ...@@ -13,17 +13,11 @@ For example: ::
shuffle=True, shuffle=True,
num_workers=args.nThreads) num_workers=args.nThreads)
The following datasets are available: .. currentmodule:: torchvision.datasets
.. contents:: Datasets
:local:
All the datasets have almost similar API. They all have two common arguments: All the datasets have almost similar API. They all have two common arguments:
``transform`` and ``target_transform`` to transform the input and target respectively. ``transform`` and ``target_transform`` to transform the input and target respectively.
You can also create your own datasets using the provided :ref:`base classes <base_classes_datasets>`.
.. currentmodule:: torchvision.datasets
Caltech Caltech
~~~~~~~ ~~~~~~~
...@@ -86,13 +80,6 @@ Detection ...@@ -86,13 +80,6 @@ Detection
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
DatasetFolder
~~~~~~~~~~~~~
.. autoclass:: DatasetFolder
:members: __getitem__
:special-members:
EMNIST EMNIST
~~~~~~ ~~~~~~
...@@ -127,13 +114,6 @@ HMDB51 ...@@ -127,13 +114,6 @@ HMDB51
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
ImageFolder
~~~~~~~~~~~
.. autoclass:: ImageFolder
:members: __getitem__
:special-members:
ImageNet ImageNet
~~~~~~~~~~~ ~~~~~~~~~~~
...@@ -263,3 +243,18 @@ WIDERFace ...@@ -263,3 +243,18 @@ WIDERFace
.. autoclass:: WIDERFace .. autoclass:: WIDERFace
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
.. _base_classes_datasets:
Base classes for custom datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: DatasetFolder
:members: __getitem__, find_classes, make_dataset
:special-members:
.. autoclass:: ImageFolder
:members: __getitem__
:special-members:
...@@ -33,30 +33,9 @@ def is_image_file(filename: str) -> bool: ...@@ -33,30 +33,9 @@ def is_image_file(filename: str) -> bool:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset structured as follows: """Finds the class folders in a dataset.
.. code::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
Args:
directory (str): Root directory path.
Raises: See :class:`DatasetFolder` for details.
FileNotFoundError: If ``directory`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
""" """
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes: if not classes:
...@@ -74,24 +53,10 @@ def make_dataset( ...@@ -74,24 +53,10 @@ def make_dataset(
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class). """Generates a list of samples of a form (path_to_sample, class).
Args: See :class:`DatasetFolder` for details.
directory (str): root dataset directory
class_to_idx (Optional[Dict[str, int]]): Dictionary mapping class name to class index. If omitted, is generated
by :func:`find_classes`.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns: Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
List[Tuple[str, int]]: samples of a form (path_to_sample, class) by default.
""" """
directory = os.path.expanduser(directory) directory = os.path.expanduser(directory)
...@@ -140,15 +105,10 @@ def make_dataset( ...@@ -140,15 +105,10 @@ def make_dataset(
class DatasetFolder(VisionDataset): class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: :: """A generic data loader.
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/[...]/xxz.ext
root/class_y/123.ext This default directory structure can be customized by overriding the
root/class_y/nsdf3.ext :meth:`find_classes` method.
root/class_y/[...]/asd932_.ext
Args: Args:
root (string): Root directory path. root (string): Root directory path.
...@@ -200,6 +160,28 @@ class DatasetFolder(VisionDataset): ...@@ -200,6 +160,28 @@ class DatasetFolder(VisionDataset):
extensions: Optional[Tuple[str, ...]] = None, extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
Args:
directory (str): root dataset directory, corresponding to ``self.root``.
class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
if class_to_idx is None: if class_to_idx is None:
# prevent potential bug since make_dataset() would use the class_to_idx logic of the # prevent potential bug since make_dataset() would use the class_to_idx logic of the
# find_classes() function, instead of using that of the find_classes() method, which # find_classes() function, instead of using that of the find_classes() method, which
...@@ -209,13 +191,34 @@ class DatasetFolder(VisionDataset): ...@@ -209,13 +191,34 @@ class DatasetFolder(VisionDataset):
) )
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
def find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Same as :func:`find_classes`. """Find the class folders in a dataset structured as follows::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
This method can be overridden to only consider This method can be overridden to only consider
a subset of classes, or to adapt to a different dataset directory structure. a subset of classes, or to adapt to a different dataset directory structure.
Args:
directory(str): Root directory path, corresponding to ``self.root``
Raises:
FileNotFoundError: If ``dir`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
""" """
return find_classes(dir) return find_classes(directory)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
...@@ -267,7 +270,7 @@ def default_loader(path: str) -> Any: ...@@ -267,7 +270,7 @@ def default_loader(path: str) -> Any:
class ImageFolder(DatasetFolder): class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: :: """A generic data loader where the images are arranged in this way by default: ::
root/dog/xxx.png root/dog/xxx.png
root/dog/xxy.png root/dog/xxy.png
...@@ -277,6 +280,9 @@ class ImageFolder(DatasetFolder): ...@@ -277,6 +280,9 @@ class ImageFolder(DatasetFolder):
root/cat/nsdf3.png root/cat/nsdf3.png
root/cat/[...]/asd932_.png root/cat/[...]/asd932_.png
This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
the same methods can be overridden to customize the dataset.
Args: Args:
root (string): Root directory path. root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image transform (callable, optional): A function/transform that takes in an PIL image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment