##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# refer to https://github.com/pytorch/vision/blob/master/torchvision/
"""Folder-based image datasets for a Tiny-ImageNet-style directory layout.

Training images are read from ``<root>/<class>/images/...``; validation images
are read from ``<root>/images/...`` with labels taken from
``<root>/val_annotations.txt``.
"""

import os
import os.path

import torch.utils.data as data
import torchvision
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``."""
    # str.endswith accepts a tuple of suffixes and checks them all in one call.
    return filename.endswith(tuple(IMG_EXTENSIONS))


def find_classes(dir):
    """Return the sorted class-folder names under *dir* and a name->index map.

    Every immediate sub-directory of *dir* is one class. Indices follow the
    sorted name order, so they are stable across runs and machines.
    """
    classes = sorted(
        d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Collect ``(path, class_index)`` pairs for the training split.

    Images are searched recursively under ``<dir>/<class>/images``. Entries of
    *dir* without an ``images`` sub-directory are skipped, as are files whose
    extension is not in ``IMG_EXTENSIONS``.

    Raises:
        KeyError: if an image is found under a folder name that is not a key
            of *class_to_idx*.
    """
    images = []
    # Sort folders and file names: os.listdir / os.walk yield entries in
    # arbitrary filesystem order, which would make sample order non-deterministic.
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target, 'images')
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append((path, class_to_idx[target]))
    return images


def default_loader(path):
    """Open the image at *path* and return it converted to RGB.

    The file is opened explicitly and closed on return. ``convert`` forces PIL
    to read the pixel data, so the handle is no longer needed afterwards. This
    avoids relying on PIL's lazy loading to release the file (see Pillow
    issue #835, which torchvision's own loader works around the same way).
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class _ImageListDataset(data.Dataset):
    """Shared indexing for datasets backed by a list of ``(path, target)`` pairs.

    Subclasses must set ``self.imgs``, ``self.loader``, ``self.transform`` and
    ``self.target_transform`` in ``__init__``.
    """

    def __getitem__(self, index):
        """Load sample *index* and return ``(image, target)`` after transforms."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)


class DatasetLoader(_ImageListDataset):
    """Training dataset with one class per sub-directory of *root*.

    Args:
        root: directory containing one folder per class, each holding its
            images (recursively) under an ``images`` sub-folder.
        transform: optional callable applied to every loaded image.
        target_transform: optional callable applied to every class index.
        loader: callable mapping an image path to an image object.

    Raises:
        RuntimeError: if no image files are found under *root*.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            # str(root) so a pathlib.Path root still yields this RuntimeError
            # rather than a TypeError from string concatenation.
            raise RuntimeError(
                "Found 0 images in subfolders of: " + str(root)
                + "\nSupported image extensions are: "
                + ",".join(IMG_EXTENSIONS)
            )

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader


def annotation_reader(root, class_to_idx):
    """Read ``<root>/val_annotations.txt`` into ``(path, class_index)`` pairs.

    Each non-blank line is tab-separated. The first field is an image file name
    under ``<root>/images`` and the second is a class name looked up in
    *class_to_idx*. Any further fields on the line are ignored.

    Raises:
        KeyError: if a class name is not a key of *class_to_idx*.
        IndexError: if a non-blank line has no tab-separated second field.
    """
    images = []
    # ``with`` guarantees the annotation file is closed even if a lookup fails.
    with open(os.path.join(root, 'val_annotations.txt'), 'r') as f:
        for line in f:
            # Drop the line terminator so a two-column file doesn't leave '\n'
            # on the class name (which would then miss in class_to_idx).
            line = line.rstrip('\r\n')
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split('\t')
            path = os.path.join(root, 'images', fields[0])
            # Tuples, matching the item type produced by make_dataset.
            images.append((path, class_to_idx[fields[1]]))
    return images


class ValDatasetLoader(_ImageListDataset):
    """Validation dataset labelled by ``<root>/val_annotations.txt``.

    Args:
        root: validation directory holding ``val_annotations.txt`` and an
            ``images`` folder.
        classes: class names, stored as-is on ``self.classes``. Presumably the
            training set's list (TODO confirm with callers).
        class_to_idx: class name -> index map used to label each image.
        transform: optional callable applied to every loaded image.
        target_transform: optional callable applied to every class index.
        loader: callable mapping an image path to an image object.
    """

    def __init__(self, root, classes, class_to_idx, transform=None,
                 target_transform=None, loader=default_loader):
        imgs = annotation_reader(root, class_to_idx)

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader