Unverified Commit 738b75f4 authored by JOHNW02's avatar JOHNW02 Committed by GitHub
Browse files

[example] Create EEG-GCNN example. (#3186)

* Create EEG-GCNN example.

* Update README.md

* Remove gitignore file.

* Update README.md

* change 'datas' to 'datasets'.

* Change train.py to main.py

* Added an entry in the indexing page.

* State "simplified version"; change how to run.

* Fix bug in contact

* Remove paper link in reference.

* Create working branch

* Add normalization of x.

* Update paper link and tags

* Update paper link in readme

* Update readme; add patient level indices

* Update readme. Add comments to models

* Update README.md

* change to with; specify location for ch and el; move note

* fix bug for note

* Add args for models; clean code.

* delete = in readme

* Add reference for spec_coh_values
parent cd2cf606
......@@ -19,7 +19,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- Tags: efficiency, node classification, label propagation
## 2020
- <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
- Example code: [PyTorch](../examples/pytorch/eeg-gcnn)
- Tags: graph classification, eeg representation learning, brain activity, graph convolution, neurological disease classification, large dataset, edge weights, node features, fully-connected graph, graph neural network
- <a name="rect"></a> Wang et al. Network Embedding with Completely-imbalanced Labels. [Paper link](https://ieeexplore.ieee.org/document/8979355).
- Example code: [PyTorch](../examples/pytorch/rect)
- Tags: node classification, network embedding, completely-imbalanced labels
......
import torch
import numpy as np
import math
import pandas as pd
import dgl
from dgl.data import DGLDataset
from itertools import product
class EEGGraphDataset(DGLDataset):
    """Per-window complete 8-node EEG graph dataset; all nodes share one type.

    Parameters
    ----------
    x : array-like, memory-mapped
        Node-feature matrix; row ``idx`` is reshaped to (8 nodes, 6 features)
        in ``__getitem__``.
    y : array-like, memory-mapped
        Labels (diseased/healthy) for every window in the full dataset.
    num_nodes : int
        Number of nodes in the graph. In this example it is 8.
    indices : sequence of int
        Patient-level window indices; they map dataset positions to rows of
        ``x``/``y`` and of the precomputed spectral-coherence matrix.

    Output
    ------
    ``__getitem__`` returns a complete 8-node DGLGraph with node features and
    edge weights, together with the global window index and the label.
    """

    def __init__(self, x, y, num_nodes, indices):
        # CAUTION - x and labels are memory-mapped, used as if they are in RAM.
        self.x = x
        self.labels = y
        self.indices = indices
        self.num_nodes = num_nodes
        # NOTE: this order decides the node index, keep consistent!
        self.ch_names = [
            "F7-F3",
            "F8-F4",
            "T7-C3",
            "T8-C4",
            "P7-P3",
            "P8-P4",
            "O1-P3",
            "O2-P4",
        ]
        # 10-10-system electrodes lying between the two 10-20 electrodes of
        # each montage pair in ch_names; used for calculating edge weights.
        # Note: "O1" stands in for "PO3", and "O2" for "PO4".
        self.ref_names = [
            "F5",
            "F6",
            "C5",
            "C6",
            "P5",
            "P6",
            "O1",
            "O2",
        ]
        # edge indices source to target - 2 x E = 2 x 64
        # fully connected undirected graph so 8*8=64 edges
        self.node_ids = range(len(self.ch_names))
        self.edge_index = torch.tensor(
            [[a, b] for a, b in product(self.node_ids, self.node_ids)],
            dtype=torch.long).t().contiguous()
        # edge attributes - E x 1; only the spatial distance between
        # electrodes for now - standardized to the [0, 1] range
        self.distances = self.get_sensor_distances()
        a = np.array(self.distances)
        self.distances = (a - np.min(a)) / (np.max(a) - np.min(a))
        # precomputed per-window spectral coherence (see README for how these
        # were generated with mne spectral_connectivity)
        self.spec_coh_values = np.load("spec_coh_values.npy", allow_pickle=True)

    # sensor distances don't depend on window ID
    def get_sensor_distances(self):
        """Return the geodesic distance for every edge in ``self.edge_index``."""
        coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep='\t')
        num_edges = self.edge_index.shape[1]
        distances = []
        for edge_idx in range(num_edges):
            sensor1_idx = self.edge_index[0, edge_idx]
            sensor2_idx = self.edge_index[1, edge_idx]
            dist = self.get_geodesic_distance(sensor1_idx, sensor2_idx, coords_1010)
            distances.append(dist)
        assert len(distances) == num_edges
        return distances

    def get_geodesic_distance(self, montage_sensor1_idx, montage_sensor2_idx, coords_1010):
        """Great-circle distance between the 10-10 reference sensors of two montage channels."""

        def _coord(sensor, axis):
            # BUGFIX: float(Series) is deprecated and raises TypeError for
            # non-unit Series in pandas >= 2; extract the scalar explicitly.
            return float(coords_1010.loc[coords_1010.label == sensor, axis].iloc[0])

        # get the reference sensor in the 10-10 system for the current montage pair in 10-20 system
        ref_sensor1 = self.ref_names[montage_sensor1_idx]
        ref_sensor2 = self.ref_names[montage_sensor2_idx]
        x1, y1, z1 = (_coord(ref_sensor1, ax) for ax in ("x", "y", "z"))
        x2, y2, z2 = (_coord(ref_sensor2, ax) for ax in ("x", "y", "z"))
        # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere
        r = 1  # since coords are on unit sphere
        # rounding is for numerical stability: the acos domain is [-1, 1]
        dist = r * math.acos(round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r ** 2), 2))
        return dist

    # returns size of dataset = number of indices
    def __len__(self):
        return len(self.indices)

    # retrieve one sample from the dataset after applying all transforms
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # map input idx (ranging from 0 to __len__() inside self.indices)
        # to an idx in the whole dataset (inside self.x)
        idx = self.indices[idx]
        node_features = self.x[idx]
        node_features = torch.from_numpy(node_features.reshape(8, 6))
        # spectral coherence between 2 montage channels
        spec_coh_values = self.spec_coh_values[idx, :]
        # combine spatial distances and spectral-coherence values into a
        # single E x 1 edge-weight tensor
        edge_weights = self.distances + spec_coh_values
        edge_weights = torch.tensor(edge_weights)
        # build the complete graph: src repeats each node id num_nodes times
        # (0,0,...,1,1,...), dst tiles all node ids (0,1,...,0,1,...)
        src = np.repeat(np.arange(self.num_nodes), self.num_nodes)
        dst = np.tile(np.arange(self.num_nodes), self.num_nodes)
        g = dgl.graph((torch.tensor(src), torch.tensor(dst)))
        # add node features and edge features
        g.ndata['x'] = node_features
        g.edata['edge_weights'] = edge_weights
        return g, torch.tensor(idx), torch.tensor(self.labels[idx])
# DGL Implementation of EEG-GCNN Paper
This example is a simplified version that presents how to utilize the original EEG-GCNN model proposed in the paper [EEG-GCNN](http://proceedings.mlr.press/v136/wagh20a.html), implemented with DGL library. The example removes cross validation and optimal decision boundary that are used in the original code. The performance stats are slightly different from what is present in the paper. The original code is [here](https://github.com/neerajwagh/eeg-gcnn).
## All References
- [ML4H Poster](https://drive.google.com/file/d/14nuAQKiIud3p6-c8r9WLV2tAvCyRwRev/view?usp=sharing) can be helpful for understanding data preprocessing, model, and performance of the project.
- The recording of presentation by the author Neeraj Wagh can be found on [slideslive](https://slideslive.com/38941020/eeggcnn-augmenting-electroencephalogrambased-neurological-disease-diagnosis-using-a-domainguided-graph-convolutional-neural-network?ref=account-folder-62123-folders).
- The slides used during the presentation can be found [here](https://drive.google.com/file/d/1dXT4QAUXKauf7CAkhrVyhR2PFUsNh4b8/view?usp=sharing).
- Raw Data can be found with these two links: [MPI LEMON](http://fcon_1000.projects.nitrc.org/indi/retro/MPI_LEMON.html) (no registration needed), [TUH EEG Abnormal Corpus](https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_abnormal/) ([needs registration](https://www.isip.piconepress.com/projects/tuh_eeg/html/request_access.php))
## Dependencies
- Python 3.8.1
- PyTorch 1.7.0
- DGL 0.6.1
- numpy 1.20.2
- Sklearn 0.24.2
- pandas 1.2.4
## Dataset
- Final Models, Pre-computed Features, Training Metadata can be downloaded through [FigShare](https://figshare.com/articles/software/EEG-GCNN_Supporting_Resources_for_Reproducibility/13251452).
- In ```EEGGraphDataset.py```, we specify the channels and electrodes and use precomputed spectral coherence values to compute the edge weights. To use this example in your own advantage, please specify your channels and electrodes in ```__init__``` function of ```EEGGraphDataset.py```.
- To generate spectral coherence values, please refer to [spectral_connectivity](https://mne.tools/stable/generated/mne.connectivity.spectral_connectivity.html) function in mne library. An example usage may take the following form:
```python
# ....loop over all windows in dataset....
# window data is 10-second preprocessed multi-channel timeseries (shape: n_channels x n_timepoints) containing all channels in ch_names
window_data = np.expand_dims(window_data, axis=0)
# ch_names are listed in EEGGraphDataset.py
for ch_idx, ch in enumerate(ch_names):
# number of channels is len(ch_names), which is 8 in our case.
spec_coh_values, _, _, _, _ = mne.connectivity.spectral_connectivity(data=window_data, method='coh', indices=([ch_idx]*8, range(8)), sfreq=SAMPLING_FREQ,
fmin=1.0, fmax=40.0, faverage=True, verbose=False)
```
## How to Run
- First, download ```figshare_upload/master_metadata_index.csv```, ```figshare_upload/psd_features_data_X```, ```figshare_upload/labels_y```, ```figshare_upload/psd_shallow_eeg-gcnn/spec_coh_values.npy```, and ```figshare_upload/psd_shallow_eeg-gcnn/standard_1010.tsv.txt```. Put them in the repo. <br>
- You may download these files by running:
```bash
wget https://ndownloader.figshare.com/files/25518170
```
- You will need to unzip the downloaded file.
- Then run:
```python
python main.py
```
- The default model used is ```shallow_EEGGraphConvNet.py```. To use ```deep_EEGGraphConvNet.py```, run:
```python
python main.py --model deep
```
- After the code executes, statistics similar to those in the Performance section below will be printed. The code will also save the trained model from every epoch.
## Performance
| DGL | AUC | Bal. Accuracy |
|-------------------|-------------|---------------|
| Shallow EEG-GCNN | 0.832 | 0.750 |
| Deep EEG-GCNN | 0.830 | 0.736 |
Shallow_EEGGraphConvNet | AUC | Bal.Accuracy |
:-------------------------:|:-------------------------:|:---------------------:|
![shallow_loss](https://user-images.githubusercontent.com/53772888/128595442-d185bd74-5c5d-4118-a6b7-b89dd307d3aa.png) |![shallow_auc](https://user-images.githubusercontent.com/53772888/128595453-2f3b181a-bcb7-4da4-becd-7a7aa62083bc.png)|![shallow_bacc](https://user-images.githubusercontent.com/53772888/128595456-b293c888-bf8c-4f37-bd58-d01885da3832.png)
Deep_EEGGraphConvNet | AUC | Bal.Accuracy |
:-------------------------:|:-------------------------:|:---------------:|
![deep_loss](https://user-images.githubusercontent.com/53772888/128595458-e4a76591-11cf-405f-9c20-2d161e49c358.png)|![deep_auc](https://user-images.githubusercontent.com/53772888/128595462-7a7bfb67-4601-4e83-8764-d7c44bf979b5.png)|![deep_bacc](https://user-images.githubusercontent.com/53772888/128595467-1a0cd37d-0152-431b-a29b-a40bafb71be5.png)
### Contact
- Email to John(_wei33@illinois.edu_)
- You may also contact the authors:
- Neeraj: nwagh2@illinois.edu / [Website](http://neerajwagh.com/) / [Twitter](https://twitter.com/neeraj_wagh) / [Google Scholar](https://scholar.google.com/citations?hl=en&user=lCy5VsUAAAAJ)
- Yoga: varatha2@illinois.edu / [Website](https://sites.google.com/view/yoga-personal/home) / [Google Scholar](https://scholar.google.com/citations?user=XwL4dBgAAAAJ&hl=en)
### Citation
Wagh, N. & Varatharajah, Y.. (2020). EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. Proceedings of the Machine Learning for Health NeurIPS Workshop, in PMLR 136:367-378 Available from http://proceedings.mlr.press/v136/wagh20a.html.
import torch.nn as nn
import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling
from torch.nn import BatchNorm1d
class EEGGraphConvNet(nn.Module):
    """Four-layer graph-convolutional network for EEG graph classification.

    Parameters
    ----------
    num_feats: the number of features per node. In our case, it is 6.
    """

    def __init__(self, num_feats):
        super(EEGGraphConvNet, self).__init__()
        # graph-convolution stack: num_feats -> 16 -> 32 -> 64 -> 50
        self.conv1 = GraphConv(num_feats, 16)
        self.conv2 = GraphConv(16, 32)
        self.conv3 = GraphConv(32, 64)
        self.conv4 = GraphConv(64, 50)
        self.conv4_bn = BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True,
                                    track_running_stats=True)
        # fully-connected classifier head: 50 -> 30 -> 10 -> 2
        self.fc_block1 = nn.Linear(50, 30)
        self.fc_block2 = nn.Linear(30, 10)
        self.fc_block3 = nn.Linear(10, 2)
        # Xavier initializations for the linear layers (same order as declared)
        for fc in (self.fc_block1, self.fc_block2, self.fc_block3):
            fc.apply(lambda m: nn.init.xavier_normal_(m.weight, gain=1))
        self.sumpool = SumPooling()

    def forward(self, g, return_graph_embedding=False):
        h = g.ndata['x']
        edge_weight = g.edata['edge_weights']
        # first three conv layers: conv -> leaky-relu -> dropout
        for conv in (self.conv1, self.conv2, self.conv3):
            h = conv(g, h, edge_weight=edge_weight)
            h = function.leaky_relu(h, negative_slope=0.01)
            h = function.dropout(h, p=0.2, training=self.training)
        # final conv layer adds batch normalization before the activation
        h = self.conv4(g, h, edge_weight=edge_weight)
        h = self.conv4_bn(h)
        h = function.leaky_relu(h, negative_slope=0.01)
        h = function.dropout(h, p=0.2, training=self.training)
        # NOTE: this takes node-level features/"embeddings" and aggregates
        # them to graph level - used for graph-level classification
        graph_embed = self.sumpool(g, h)
        if return_graph_embedding:
            return graph_embed
        out = function.leaky_relu(self.fc_block1(graph_embed), negative_slope=0.1)
        out = function.dropout(out, p=0.2, training=self.training)
        out = function.leaky_relu(self.fc_block2(out), negative_slope=0.1)
        out = function.dropout(out, p=0.2, training=self.training)
        out = self.fc_block3(out)
        return out
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from joblib import load
from EEGGraphDataset import EEGGraphDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import balanced_accuracy_score
from sklearn import preprocessing
if __name__ == "__main__":
    # -------- commandline args --------
    parser = argparse.ArgumentParser(description='Execute training pipeline on a given train/val subjects')
    parser.add_argument('--num_feats', type=int, default=6, help='Number of features per node for the graph')
    parser.add_argument('--num_nodes', type=int, default=8, help='Number of nodes in the graph')
    parser.add_argument('--gpu_idx', type=int, default=0,
                        help='index of GPU device that should be used for this run, defaults to 0.')
    parser.add_argument('--num_epochs', type=int, default=40, help='Number of epochs used to train')
    parser.add_argument('--exp_name', type=str, default='default', help='Name for the test.')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch Size. Default is 512.')
    parser.add_argument('--model', type=str, default='shallow',
                        help='type shallow to use shallow_EEGGraphDataset; '
                             'type deep to use deep_EEGGraphDataset. Default is shallow')
    args = parser.parse_args()

    # choose model; fail fast on an unrecognized choice instead of hitting a
    # NameError when EEGGraphConvNet is referenced later
    if args.model == 'shallow':
        from shallow_EEGGraphConvNet import EEGGraphConvNet
    elif args.model == 'deep':
        from deep_EEGGraphConvNet import EEGGraphConvNet
    else:
        raise ValueError(f"Unknown --model value '{args.model}'; use 'shallow' or 'deep'.")

    # set the random seed so that we can reproduce the results
    np.random.seed(42)
    torch.manual_seed(42)

    # use GPU when available. BUGFIX: torch.cuda.set_device() and
    # torch.cuda.get_device_name() raise on CPU-only machines, so only
    # call them when CUDA actually is available.
    _GPU_IDX = args.gpu_idx
    if torch.cuda.is_available():
        _DEVICE = torch.device(f'cuda:{_GPU_IDX}')
        torch.cuda.set_device(_DEVICE)
        print(f' Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}')
    else:
        _DEVICE = torch.device('cpu')
        print(f' Using device: {_DEVICE}')

    # load patient level indices
    _DATASET_INDEX = pd.read_csv("master_metadata_index.csv")
    all_subjects = _DATASET_INDEX["patient_ID"].astype("str").unique()
    print(f"Subject list fetched! Total subjects are {len(all_subjects)}.")

    # retrieve inputs
    num_nodes = args.num_nodes
    _NUM_EPOCHS = args.num_epochs
    _EXPERIMENT_NAME = args.exp_name
    _BATCH_SIZE = args.batch_size
    num_feats = args.num_feats

    # set up input and targets from files (memory-mapped, not read into RAM)
    memmap_x = 'psd_features_data_X'
    memmap_y = 'labels_y'
    x = load(memmap_x, mmap_mode='r')
    y = load(memmap_y, mmap_mode='r')

    # normalize psd features data: L2-normalize each 48-dim feature row
    normd_x = []
    for i in range(len(y)):
        arr = x[i, :].reshape(1, -1)
        normd_x.append(preprocessing.normalize(arr).reshape(48))
    x = np.array(normd_x).reshape(len(y), 48)

    # map 0/1 to diseased/healthy
    label_mapping, y = np.unique(y, return_inverse=True)
    print(f"Unique labels 0/1 mapping: {label_mapping}")

    # split the dataset to train and test at the patient level. The ratio of test is 0.3.
    train_and_val_subjects, heldout_subjects = train_test_split(all_subjects, test_size=0.3, random_state=42)
    # split the dataset using patient indices
    train_window_indices = _DATASET_INDEX.index[
        _DATASET_INDEX["patient_ID"].astype("str").isin(train_and_val_subjects)].tolist()
    heldout_test_window_indices = _DATASET_INDEX.index[
        _DATASET_INDEX["patient_ID"].astype("str").isin(heldout_subjects)].tolist()

    # define model, optimizer, scheduler
    model = EEGGraphConvNet(num_feats)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1)
    model = model.to(_DEVICE).double()
    num_trainable_params = np.sum([np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()])

    # Dataloader =======================================================================================================
    # use WeightedRandomSampler to balance the training dataset
    NUM_WORKERS = 4
    labels_unique, counts = np.unique(y, return_counts=True)
    class_weights = np.array([1.0 / count for count in counts])
    # provide weights for samples in the training set only
    sample_weights = class_weights[y[train_window_indices]]
    # sampler needs to come up with training set size number of samples
    weighted_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_window_indices), replacement=True
    )

    # train data loader
    train_dataset = EEGGraphDataset(
        x=x, y=y, num_nodes=num_nodes, indices=train_window_indices
    )
    train_loader = GraphDataLoader(
        dataset=train_dataset, batch_size=_BATCH_SIZE,
        sampler=weighted_sampler,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    # this loader is used without weighted sampling, to evaluate metrics on full training set after each epoch
    train_metrics_loader = GraphDataLoader(
        dataset=train_dataset, batch_size=_BATCH_SIZE,
        shuffle=False, num_workers=NUM_WORKERS,
        pin_memory=True
    )
    # test data loader
    test_dataset = EEGGraphDataset(
        x=x, y=y, num_nodes=num_nodes, indices=heldout_test_window_indices
    )
    test_loader = GraphDataLoader(
        dataset=test_dataset, batch_size=_BATCH_SIZE,
        shuffle=False, num_workers=NUM_WORKERS,
        pin_memory=True
    )

    # per-epoch metric histories
    auroc_train_history = []
    auroc_test_history = []
    balACC_train_history = []
    balACC_test_history = []
    loss_train_history = []
    loss_test_history = []

    # training =========================================================================================================
    for epoch in range(_NUM_EPOCHS):
        model.train()
        train_loss = []
        for batch_idx, batch in enumerate(train_loader):
            # send batch to GPU. BUGFIX: unpack labels into batch_labels
            # instead of shadowing the full label array y.
            g, dataset_idx, batch_labels = batch
            g_batch = g.to(device=_DEVICE, non_blocking=True)
            y_batch = batch_labels.to(device=_DEVICE, non_blocking=True)
            optimizer.zero_grad()
            # forward pass
            outputs = model(g_batch)
            loss = loss_function(outputs, y_batch)
            train_loss.append(loss.item())
            # backward pass
            loss.backward()
            optimizer.step()
        # update learning rate once per epoch
        scheduler.step()

        # evaluate model after each epoch for train-metric data ===========================================================
        model.eval()
        with torch.no_grad():
            y_probs_train = torch.empty(0, 2).to(_DEVICE)
            y_true_train, y_pred_train = [], []
            for i, batch in enumerate(train_metrics_loader):
                g, dataset_idx, batch_labels = batch
                g_batch = g.to(device=_DEVICE, non_blocking=True)
                y_batch = batch_labels.to(device=_DEVICE, non_blocking=True)
                # forward pass
                outputs = model(g_batch)
                _, predicted = torch.max(outputs.data, 1)
                y_pred_train += predicted.cpu().numpy().tolist()
                # concatenate along 0th dimension
                y_probs_train = torch.cat((y_probs_train, outputs.data), 0)
                y_true_train += y_batch.cpu().numpy().tolist()
            # returning prob distribution over target classes, take softmax over the 1st dimension
            y_probs_train = nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
            y_true_train = np.array(y_true_train)

            # evaluate model after each epoch for validation data =========================================================
            y_probs_test = torch.empty(0, 2).to(_DEVICE)
            y_true_test, minibatch_loss, y_pred_test = [], [], []
            for i, batch in enumerate(test_loader):
                g, dataset_idx, batch_labels = batch
                g_batch = g.to(device=_DEVICE, non_blocking=True)
                y_batch = batch_labels.to(device=_DEVICE, non_blocking=True)
                # forward pass
                outputs = model(g_batch)
                _, predicted = torch.max(outputs.data, 1)
                y_pred_test += predicted.cpu().numpy().tolist()
                loss = loss_function(outputs, y_batch)
                minibatch_loss.append(loss.item())
                y_probs_test = torch.cat((y_probs_test, outputs.data), 0)
                y_true_test += y_batch.cpu().numpy().tolist()
            # returning prob distribution over target classes, take softmax over the 1st dimension
            y_probs_test = torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy()
            y_true_test = np.array(y_true_test)

        # record training auroc and testing auroc
        auroc_train_history.append(roc_auc_score(y_true_train, y_probs_train[:, 1]))
        auroc_test_history.append(roc_auc_score(y_true_test, y_probs_test[:, 1]))
        # record training balanced accuracy and testing balanced accuracy
        balACC_train_history.append(balanced_accuracy_score(y_true_train, y_pred_train))
        balACC_test_history.append(balanced_accuracy_score(y_true_test, y_pred_test))
        # LOSS - epoch loss is defined as mean of minibatch losses within epoch
        loss_train_history.append(np.mean(train_loss))
        loss_test_history.append(np.mean(minibatch_loss))

        # print the metrics
        print("Train loss: {}, test loss: {}".format(loss_train_history[-1], loss_test_history[-1]))
        print("Train AUC: {}, test AUC: {}".format(auroc_train_history[-1], auroc_test_history[-1]))
        print("Train Bal.ACC: {}, test Bal.ACC: {}".format(balACC_train_history[-1], balACC_test_history[-1]))

        # save model from each epoch ===================================================================================
        state = {
            'epochs': _NUM_EPOCHS,
            'experiment_name': _EXPERIMENT_NAME,
            'model_description': str(model),
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        torch.save(state, f"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt")
import torch.nn as nn
import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling
class EEGGraphConvNet(nn.Module):
    """Shallow two-layer graph-convolutional network for EEG graph classification.

    Parameters
    ----------
    num_feats: the number of features per node. In our case, it is 6.
    """

    def __init__(self, num_feats):
        super(EEGGraphConvNet, self).__init__()
        self.conv1 = GraphConv(num_feats, 32)
        self.conv2 = GraphConv(32, 20)
        self.conv2_bn = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.fc_block1 = nn.Linear(20, 10)
        self.fc_block2 = nn.Linear(10, 2)
        # Xavier initializations
        self.fc_block1.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))
        self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))
        # CONSISTENCY FIX: build the pooling module once here (as the deep
        # model does) instead of re-instantiating it on every forward call.
        # SumPooling has no parameters, so state_dict keys are unchanged.
        self.sumpool = SumPooling()

    def forward(self, g, return_graph_embedding=False):
        x = g.ndata['x']
        edge_weight = g.edata['edge_weights']
        x = function.leaky_relu(self.conv1(g, x, edge_weight=edge_weight))
        x = function.leaky_relu(self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight)))
        # NOTE: this takes node-level features/"embeddings"
        # and aggregates to graph-level - use for graph-level classification
        out = self.sumpool(g, x)
        if return_graph_embedding:
            return out
        out = function.dropout(out, p=0.2, training=self.training)
        out = self.fc_block1(out)
        out = function.leaky_relu(out)
        out = self.fc_block2(out)
        return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment