Unverified Commit f9af70a9 authored by Alessandro Melis's avatar Alessandro Melis Committed by GitHub
Browse files

Make DatasetFolder.find_classes public (#3628)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 249e1a98
...@@ -182,7 +182,7 @@ class DatasetFolder(VisionDataset): ...@@ -182,7 +182,7 @@ class DatasetFolder(VisionDataset):
) -> None: ) -> None:
super(DatasetFolder, self).__init__(root, transform=transform, super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root) classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
self.loader = loader self.loader = loader
...@@ -202,8 +202,12 @@ class DatasetFolder(VisionDataset): ...@@ -202,8 +202,12 @@ class DatasetFolder(VisionDataset):
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
@staticmethod def find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: """Same as :func:`find_classes`.
This method can be overridden to only consider
a subset of classes, or to adapt to a different dataset directory structure.
"""
return find_classes(dir) return find_classes(dir)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment