"src/vscode:/vscode.git/clone" did not exist on "dc3e0ca59bf26ebcc9f12ed186bfe8fca86c3a1b"
Commit 5c39840c authored by Sergey Zagoruyko's avatar Sergey Zagoruyko Committed by Adam Paszke
Browse files

Allow to pass load function in ImageFolder (#20)

parent df557474
......@@ -33,8 +33,14 @@ def make_dataset(dir, class_to_idx):
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
......@@ -44,10 +50,11 @@ class ImageFolder(data.Dataset):
self.class_to_idx = class_to_idx
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
path, target = self.imgs[index]
img = Image.open(os.path.join(self.root, path)).convert('RGB')
img = self.loader(os.path.join(self.root, path))
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment