"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "58cca6e94fa0633e04e164814b7db981dd78289f"
Unverified Commit bf7994b0 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

folder (#2523)

parent 40333c5a
......@@ -4,9 +4,10 @@ from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
def has_file_allowed_extension(filename, extensions):
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
......@@ -19,7 +20,7 @@ def has_file_allowed_extension(filename, extensions):
return filename.lower().endswith(extensions)
def is_image_file(filename):
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
......@@ -31,7 +32,12 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
......@@ -39,8 +45,9 @@ def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
......@@ -87,8 +94,15 @@ class DatasetFolder(VisionDataset):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
......@@ -107,7 +121,7 @@ class DatasetFolder(VisionDataset):
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
"""
Finds the class folders in a dataset.
......@@ -125,7 +139,7 @@ class DatasetFolder(VisionDataset):
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -142,21 +156,22 @@ class DatasetFolder(VisionDataset):
return sample, target
def __len__(self):
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path):
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def accimage_loader(path):
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
......@@ -165,7 +180,7 @@ def accimage_loader(path):
return pil_loader(path)
def default_loader(path):
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
......@@ -200,8 +215,14 @@ class ImageFolder(DatasetFolder):
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment