Commit 0d650443 authored by Ivan Brugere's avatar Ivan Brugere
Browse files

cleaning

cleaning and documentation
parent ef18dab2
...@@ -4,22 +4,22 @@ ...@@ -4,22 +4,22 @@
Created on Mon Jul 9 13:34:38 2018 Created on Mon Jul 9 13:34:38 2018
@author: ivabruge @author: ivabruge
"""
""" GeniePath: Graph Neural Networks with Adaptive Receptive Paths
Graph Attention Networks Paper: https://arxiv.org/abs/1802.00910
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT this model uses an LSTM on the node reductions of the message-passing step
we store the network states at the graph node, since the LSTM variables are not transmitted
""" """
import networkx as nx
from dgl.graph import DGLGraph from dgl.graph import DGLGraph
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import argparse import argparse
from dataset import load_data, preprocess_features from dataset import load_data, preprocess_features
import numpy as np
class NodeReduceModule(nn.Module): class NodeReduceModule(nn.Module):
def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None, def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None,
...@@ -101,10 +101,10 @@ class NodeUpdateModule(nn.Module): ...@@ -101,10 +101,10 @@ class NodeUpdateModule(nn.Module):
return {'h': h, 'c':c, 'h_i':h_i} return {'h': h, 'c':c, 'h_i':h_i}
class GiniPath(nn.Module): class GeniePath(nn.Module):
def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads, def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
activation, input_dropout, attention_dropout, use_residual=False ): activation, input_dropout, attention_dropout, use_residual=False ):
super(GiniPath, self).__init__() super(GeniePath, self).__init__()
self.input_dropout = input_dropout self.input_dropout = input_dropout
self.reduce_layers = nn.ModuleList() self.reduce_layers = nn.ModuleList()
...@@ -147,15 +147,18 @@ class GiniPath(nn.Module): ...@@ -147,15 +147,18 @@ class GiniPath(nn.Module):
logits = [g.node[n]['h'] for n in g.nodes()] logits = [g.node[n]['h'] for n in g.nodes()]
logits = torch.cat(logits, dim=0) logits = torch.cat(logits, dim=0)
return logits return logits
#train on graph g with features, and target labels. Accepts a loss function and an optimizer function which implements optimizer.step()
def train(self, g, features, labels, epochs, loss_f=torch.nn.NLLLoss, loss_params={}, optimizer=torch.optim.Adam, optimizer_parameters=None, lr=0.001, ignore=[0], quiet=False): def train(self, g, features, labels, epochs, loss_f=torch.nn.NLLLoss, loss_params={}, optimizer=torch.optim.Adam, optimizer_parameters=None, lr=0.001, ignore=[0], quiet=False):
labels = torch.LongTensor(labels) labels = torch.LongTensor(labels)
print(labels)
_, labels = torch.max(labels, dim=1) _, labels = torch.max(labels, dim=1)
# convert labels and masks to tensor # convert labels and masks to tensor
if optimizer_parameters is None: if optimizer_parameters is None:
optimizer_parameters = self.parameters() optimizer_parameters = self.parameters()
#instantiate optimizer on given params
optimizer_f = optimizer(optimizer_parameters, lr) optimizer_f = optimizer(optimizer_parameters, lr)
for epoch in range(args.epochs): for epoch in range(args.epochs):
...@@ -169,7 +172,10 @@ class GiniPath(nn.Module): ...@@ -169,7 +172,10 @@ class GiniPath(nn.Module):
# forward # forward
logits = self.forward(g) logits = self.forward(g)
#intantiate loss on passed parameters (e.g. class weight params)
loss = loss_f(**loss_params) loss = loss_f(**loss_params)
#trim null labels
idx = [i for i, a in enumerate(labels) if a not in ignore] idx = [i for i, a in enumerate(labels) if a not in ignore]
logits = logits[idx, :] logits = logits[idx, :]
labels = labels[idx] labels = labels[idx]
...@@ -183,8 +189,8 @@ class GiniPath(nn.Module): ...@@ -183,8 +189,8 @@ class GiniPath(nn.Module):
def main(args): def main(args):
# dropout parameters # dropout parameters
input_dropout = 0.2 input_dropout = args.idrop
attention_dropout = 0.2 attention_dropout = args.adrop
# load and preprocess dataset # load and preprocess dataset
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset) adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
...@@ -194,7 +200,7 @@ def main(args): ...@@ -194,7 +200,7 @@ def main(args):
g = DGLGraph(adj) g = DGLGraph(adj)
# create model # create model
model = GiniPath(args.num_layers, model = GeniePath(args.num_layers,
features.shape[1], features.shape[1],
args.num_hidden, args.num_hidden,
y_train.shape[1], y_train.shape[1],
...@@ -203,7 +209,7 @@ def main(args): ...@@ -203,7 +209,7 @@ def main(args):
input_dropout, input_dropout,
attention_dropout, attention_dropout,
args.residual) args.residual)
model.train(g, features, y_train, epochs=10) model.train(g, features, y_train, epochs=args.epochs)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT') parser = argparse.ArgumentParser(description='GAT')
...@@ -221,6 +227,11 @@ if __name__ == '__main__': ...@@ -221,6 +227,11 @@ if __name__ == '__main__':
help="use residual connection") help="use residual connection")
parser.add_argument("--lr", type=float, default=0.001, parser.add_argument("--lr", type=float, default=0.001,
help="learning rate") help="learning rate")
parser.add_argument("--idrop", type=float, default=0.2,
help="Input dropout")
parser.add_argument("--adrop", type=float, default=0.2,
help="attention dropout")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment