Unverified Commit c334662b authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Removing warnings in the EEG-GNN example. (#5925)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 6f28e1ad
...@@ -7,12 +7,22 @@ import torch.nn as nn ...@@ -7,12 +7,22 @@ import torch.nn as nn
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from EEGGraphDataset import EEGGraphDataset from EEGGraphDataset import EEGGraphDataset
from joblib import load from joblib import dump, load
from sklearn import preprocessing from sklearn import preprocessing
from sklearn.metrics import balanced_accuracy_score, roc_auc_score from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from torch.utils.data import WeightedRandomSampler from torch.utils.data import WeightedRandomSampler
def _load_memory_mapped_array(file_name):
# Due to a legacy problem related to memory alignment in joblib [1], the
# data provided in the example may not be byte-aligned. This can be risky
# when loading with mmap_mode. To fix the issue, load and re-dump the data.
# [1] https://joblib.readthedocs.io/en/latest/developing.html#release-1-2-0
dump(load(file_name), file_name)
return load(file_name, mmap_mode="r")
if __name__ == "__main__": if __name__ == "__main__":
# argparse commandline args # argparse commandline args
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -77,7 +87,7 @@ if __name__ == "__main__": ...@@ -77,7 +87,7 @@ if __name__ == "__main__":
print(f" Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}") print(f" Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}")
# load patient level indices # load patient level indices
_DATASET_INDEX = pd.read_csv("master_metadata_index.csv") _DATASET_INDEX = pd.read_csv("master_metadata_index.csv", low_memory=False)
all_subjects = _DATASET_INDEX["patient_ID"].astype("str").unique() all_subjects = _DATASET_INDEX["patient_ID"].astype("str").unique()
print(f"Subject list fetched! Total subjects are {len(all_subjects)}.") print(f"Subject list fetched! Total subjects are {len(all_subjects)}.")
...@@ -89,10 +99,8 @@ if __name__ == "__main__": ...@@ -89,10 +99,8 @@ if __name__ == "__main__":
num_feats = args.num_feats num_feats = args.num_feats
# set up input and targets from files # set up input and targets from files
memmap_x = f"psd_features_data_X" x = _load_memory_mapped_array(f"psd_features_data_X")
memmap_y = f"labels_y" y = _load_memory_mapped_array(f"labels_y")
x = load(memmap_x, mmap_mode="r")
y = load(memmap_y, mmap_mode="r")
# normalize psd features data # normalize psd features data
normd_x = [] normd_x = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment