Commit 5c39840c authored by Sergey Zagoruyko's avatar Sergey Zagoruyko Committed by Adam Paszke
Browse files

Allow to pass load function in ImageFolder (#20)

parent df557474
...@@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx): ...@@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx):
return images return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset): class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root) classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx) imgs = make_dataset(root, class_to_idx)
...@@ -44,10 +50,11 @@ class ImageFolder(data.Dataset): ...@@ -44,10 +50,11 @@ class ImageFolder(data.Dataset):
self.class_to_idx = class_to_idx self.class_to_idx = class_to_idx
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index): def __getitem__(self, index):
path, target = self.imgs[index] path, target = self.imgs[index]
img = Image.open(os.path.join(self.root, path)).convert('RGB') img = self.loader(os.path.join(self.root, path))
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
if self.target_transform is not None: if self.target_transform is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment