"""Loading utilities and a PyTorch ``Dataset`` for a FER2013-style CSV file.

The CSV is expected to have a header row followed by rows whose first three
columns are: integer label, space-separated pixel intensities, usage tag.
"""
import csv

import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms

# Progress is reported every this many CSV rows (the header row is counted).
_PROGRESS_EVERY = 5000


def _is_reporting_rank(local_rank):
    """Return True if this process should print progress messages.

    Only rank 0, or a run where ``local_rank`` is ``None``, reports.
    """
    return local_rank is None or local_rank == 0


class DataFile():
    """In-memory view of the image CSV.

    Attributes:
        labels: list of int labels, one per data row (column 0).
        pics: list of normalized 3x224x224 tensors built from column 1.
        usage: list of raw usage strings (column 2).
        local_rank: the rank value passed to the constructor.
    """

    def __init__(self, path, local_rank):
        """Read ``path`` and eagerly convert every row to a tensor.

        Args:
            path: location of the CSV file; the first row is treated as a
                header and skipped.
            local_rank: process rank; progress is printed only when it is
                ``None`` or ``0``.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a label or pixel value is not an integer.
            IndexError: if a non-empty row has fewer than three columns.
        """
        self.labels = []
        self.pics = []
        self.usage = []
        self.local_rank = local_rank

        ts_proc = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.485, 0.456, 0.406),
                                  (0.229, 0.224, 0.225))])

        # ``with`` guarantees the handle is closed even if a row fails to
        # parse; ``newline=''`` is what the csv module asks for.
        with open(path, 'r', newline='') as f:
            for ln, row in enumerate(csv.reader(f)):
                # Row 0 is the header; blank lines come back as [] and
                # carry no data.
                if ln != 0 and row:
                    self.labels.append(int(row[0]))
                    self.pics.append(ts_proc(self._decode_image(row[1])))
                    self.usage.append(row[2])
                rows_read = ln + 1
                if (rows_read % _PROGRESS_EVERY == 0
                        and _is_reporting_rank(local_rank)):
                    print("{} pics loaded.".format(rows_read))

    @staticmethod
    def _decode_image(pixel_field):
        """Turn a space-separated pixel string into a 224x224 RGB array."""
        # split() without an argument tolerates repeated or trailing spaces.
        pixels = np.array([int(x) for x in pixel_field.split()],
                          dtype=np.uint8)
        # cv2.resize needs a 2-D image. Passing the flat vector makes OpenCV
        # treat it as a one-pixel-wide image, so the resized output is just
        # stripes. Restore the square layout when the pixel count allows it.
        # NOTE(review): assumes images are square (e.g. 48x48) - confirm
        # against the data file. Non-square counts keep the old behaviour.
        side = int(round(np.sqrt(pixels.size)))
        if side > 0 and side * side == pixels.size:
            pixels = pixels.reshape(side, side)
        resized = cv2.resize(pixels, (224, 224))
        return cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    def to_file(self):
        """Placeholder; not implemented."""
        pass

    def get_data(self):
        """Return ``(labels, pics, usage)`` as three parallel lists."""
        return self.labels, self.pics, self.usage


class LabelFile():
    """Labels derived from a CSV of per-class score columns.

    For each data row the label is the arg-max over columns 2 onward.
    """

    def __init__(self, path, local_rank):
        """Read ``path`` and compute one arg-max label per data row.

        Args:
            path: location of the CSV file; the first row is skipped.
            local_rank: process rank; progress is printed only when it is
                ``None`` or ``0``.
        """
        self.labels = []
        with open(path, 'r', newline='') as f:
            for ln, row in enumerate(csv.reader(f, delimiter=',')):
                if ln != 0 and row:
                    lab_cells = np.array(row[2:], dtype=np.uint8)
                    self.labels.append(np.argmax(lab_cells))
                rows_read = ln + 1
                if (rows_read % _PROGRESS_EVERY == 0
                        and _is_reporting_rank(local_rank)):
                    print("{} labels loaded.".format(rows_read))

    def get_labels(self):
        """Return the list of arg-max labels."""
        return self.labels


class Fer2013Dataset(Dataset):
    """Three-class subset of the image CSV with a train/test split.

    Only rows labelled 1, 2 or 3 are kept; they are renumbered 0, 1, 2.
    ``set_mode('train' | 'test')`` selects which split ``__len__`` and
    ``__getitem__`` serve.
    """

    # Source label -> training label; rows with any other label are dropped.
    _LABEL_REMAP = {1: 0, 2: 1, 3: 2}

    def __init__(self, local_rank):
        """Load the CSV and build the split.

        Args:
            local_rank: process rank, or ``None``.
        """
        print('local_rank_datawork:', local_rank)
        self.datafile = DataFile('data/fer2013//DDP_data_231017.csv',
                                 local_rank)
        self.mode = 'train'
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.randomization(0 if local_rank is None else local_rank)

    def randomization(self, seed):
        """Filter to the kept classes and create a stratified 80/20 split.

        Args:
            seed: accepted for interface compatibility but currently
                ignored; the split always uses ``random_state=0``, so every
                caller gets the same partition.
        """
        labels, pics, _usage = self.datafile.get_data()
        tarpics = []
        tarlabs = []
        for label, pic in zip(labels, pics):
            remapped = self._LABEL_REMAP.get(label)
            if remapped is not None:
                tarpics.append(pic)
                tarlabs.append(remapped)
        self.X_train, self.X_test, self.y_train, self.y_test = \
            train_test_split(tarpics, tarlabs, test_size=0.2,
                             random_state=0, stratify=tarlabs)

    def _current_split(self):
        """Return ``(X, y)`` for the active mode.

        Raises:
            ValueError: if ``self.mode`` is neither 'train' nor 'test'.
                The previous code fell through and returned ``None``.
        """
        if self.mode == 'train':
            return self.X_train, self.y_train
        if self.mode == 'test':
            return self.X_test, self.y_test
        raise ValueError(
            "unknown mode {!r}; expected 'train' or 'test'".format(self.mode))

    def __len__(self):
        """Number of samples in the active split."""
        _, targets = self._current_split()
        return len(targets)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` from the active split."""
        samples, targets = self._current_split()
        return samples[index], torch.tensor(targets[index])

    def set_mode(self, mode):
        """Select the active split: 'train' or 'test'."""
        self.mode = mode


def show_pic(pixels):
    """Show ``pixels`` in an OpenCV window and block until a key is pressed."""
    cv2.imshow('Show', pixels)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # Smoke test: build the dataset as rank 0.
    x = Fer2013Dataset(0)
    print('done')