import torch import numpy as np import math import pandas as pd import dgl from dgl.data import DGLDataset from itertools import product class EEGGraphDataset(DGLDataset): """ Build graph, treat all nodes as the same type Parameters ---------- x: edge weights of 8-node complete graph There are 1 x 64 edges y: labels (diseased/healthy) num_nodes: the number of nodes of the graph. In our case, it is 8. indices: Patient level indices. They are used to generate edge weights. Output ------ a complete 8-node DGLGraph with node features and edge weights """ def __init__(self, x, y, num_nodes, indices): # CAUTION - x and labels are memory-mapped, used as if they are in RAM. self.x = x self.labels = y self.indices = indices self.num_nodes = num_nodes # NOTE: this order decides the node index, keep consistent! self.ch_names = [ "F7-F3", "F8-F4", "T7-C3", "T8-C4", "P7-P3", "P8-P4", "O1-P3", "O2-P4" ] # in the 10-10 system, in between the 2 10-20 electrodes in ch_names, used for calculating edge weights # Note: "01" is for "P03", and "02" is for "P04." self.ref_names = [ "F5", "F6", "C5", "C6", "P5", "P6", "O1", "O2" ] # edge indices source to target - 2 x E = 2 x 64 # fully connected undirected graph so 8*8=64 edges self.node_ids = range(len(self.ch_names)) self.edge_index = torch.tensor([[a, b] for a, b in product(self.node_ids, self.node_ids)], dtype=torch.long).t().contiguous() # edge attributes - E x 1 # only the spatial distance between electrodes for now - standardize between 0 and 1 self.distances = self.get_sensor_distances() a = np.array(self.distances) self.distances = (a - np.min(a)) / (np.max(a) - np.min(a)) self.spec_coh_values = np.load("spec_coh_values.npy", allow_pickle=True) # sensor distances don't depend on window ID def get_sensor_distances(self): coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep='\t') num_edges = self.edge_index.shape[1] distances = [] for edge_idx in range(num_edges): sensor1_idx = self.edge_index[0, edge_idx] sensor2_idx = self.edge_index[1, edge_idx] dist = self.get_geodesic_distance(sensor1_idx, sensor2_idx, coords_1010) distances.append(dist) assert len(distances) == num_edges return distances def get_geodesic_distance(self, montage_sensor1_idx, montage_sensor2_idx, coords_1010): # get the reference sensor in the 10-10 system for the current montage pair in 10-20 system ref_sensor1 = self.ref_names[montage_sensor1_idx] ref_sensor2 = self.ref_names[montage_sensor2_idx] x1 = float(coords_1010[coords_1010.label == ref_sensor1]["x"]) y1 = float(coords_1010[coords_1010.label == ref_sensor1]["y"]) z1 = float(coords_1010[coords_1010.label == ref_sensor1]["z"]) x2 = float(coords_1010[coords_1010.label == ref_sensor2]["x"]) y2 = float(coords_1010[coords_1010.label == ref_sensor2]["y"]) z2 = float(coords_1010[coords_1010.label == ref_sensor2]["z"]) # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere r = 1 # since coords are on unit sphere # rounding is for numerical stability, domain is [-1, 1] dist = r * math.acos(round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r ** 2), 2)) return dist # returns size of dataset = number of indices def __len__(self): return len(self.indices) # retrieve one sample from the dataset after applying all transforms def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() # map input idx (ranging from 0 to __len__() inside self.indices) # to an idx in the whole dataset (inside self.x) # assert idx < len(self.indices) idx = self.indices[idx] node_features = self.x[idx] node_features = torch.from_numpy(node_features.reshape(8, 6)) # spectral coherence between 2 montage channels! spec_coh_values = self.spec_coh_values[idx, :] # combine edge weights and spect coh values into one value/ one E x 1 tensor edge_weights = self.distances + spec_coh_values edge_weights = torch.tensor(edge_weights) # trucated to integer # create 8-node complete graph src = [[0 for i in range(self.num_nodes)] for j in range(self.num_nodes)] for i in range(len(src)): for j in range(len(src[i])): src[i][j] = i src = np.array(src).flatten() det = [[i for i in range(self.num_nodes)] for j in range(self.num_nodes)] det = np.array(det).flatten() u, v = (torch.tensor(src), torch.tensor(det)) g = dgl.graph((u, v)) # add node features and edge features g.ndata['x'] = node_features g.edata['edge_weights'] = edge_weights return g, torch.tensor(idx), torch.tensor(self.labels[idx])