Commit f566fac8 authored by James Hamm's avatar James Hamm Committed by Soumith Chintala
Browse files

Replace endswith calls in a loop with a single endswith call (#832)

str.endswith can take a single string, or a tuple of strings. This replaces a loop over each extension and a call to endswith with a single call to endswith passing in all the allowed extensions. This has the advantage that the loop over each extension is done in c rather than python, and the code is a little less verbose.
parent ad0daff1
...@@ -12,13 +12,12 @@ def has_file_allowed_extension(filename, extensions): ...@@ -12,13 +12,12 @@ def has_file_allowed_extension(filename, extensions):
Args: Args:
filename (string): path to a file filename (string): path to a file
extensions (iterable of strings): extensions to consider (lowercase) extensions (tuple of strings): extensions to consider (lowercase)
Returns: Returns:
bool: True if the filename ends with one of given extensions bool: True if the filename ends with one of given extensions
""" """
filename_lower = filename.lower() return filename.lower().endswith(extensions)
return any(filename_lower.endswith(ext) for ext in extensions)
def is_image_file(filename): def is_image_file(filename):
...@@ -65,7 +64,7 @@ class DatasetFolder(data.Dataset): ...@@ -65,7 +64,7 @@ class DatasetFolder(data.Dataset):
Args: Args:
root (string): Root directory path. root (string): Root directory path.
loader (callable): A function to load a sample given its path. loader (callable): A function to load a sample given its path.
extensions (list[string]): A list of allowed extensions. extensions (tuple[string]): A list of allowed extensions.
transform (callable, optional): A function/transform that takes in transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version. a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images. E.g, ``transforms.RandomCrop`` for images.
...@@ -151,7 +150,7 @@ class DatasetFolder(data.Dataset): ...@@ -151,7 +150,7 @@ class DatasetFolder(data.Dataset):
return fmt_str return fmt_str
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp'] IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp')
def pil_loader(path): def pil_loader(path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment