"...text-generation-inference.git" did not exist on "f7f61876cff78d934b96cb80a8b312d5f9600802"
Commit 0d650443 authored by Ivan Brugere

cleaning

cleaning and documentation
parent ef18dab2
@@ -4,22 +4,22 @@
Created on Mon Jul 9 13:34:38 2018

@author: ivabruge
"""

"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT

GeniePath: Graph Neural Networks with Adaptive Receptive Paths
Paper: https://arxiv.org/abs/1802.00910

This model uses an LSTM on the node reductions of the message-passing step;
the LSTM state is stored at each graph node, since these variables are not
transmitted during message passing.
"""
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np
class NodeReduceModule(nn.Module):
    def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None,
@@ -101,10 +101,10 @@ class NodeUpdateModule(nn.Module):
        return {'h': h, 'c': c, 'h_i': h_i}

-class GiniPath(nn.Module):
+class GeniePath(nn.Module):
    def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
                 activation, input_dropout, attention_dropout, use_residual=False):
-        super(GiniPath, self).__init__()
+        super(GeniePath, self).__init__()
        self.input_dropout = input_dropout
        self.reduce_layers = nn.ModuleList()
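An aside on the empty nn.ModuleList() above: registering submodules through a ModuleList is what makes their weights appear in self.parameters(), and hence in the optimizer default used by train() below. A hypothetical sketch of how such a list is commonly populated, with constructor arguments assumed from the truncated NodeReduceModule signature earlier:

# hypothetical population of the list; the argument names are assumptions
self.reduce_layers = nn.ModuleList(
    NodeReduceModule(num_hidden, num_hidden, num_heads=num_heads,
                     input_dropout=input_dropout)
    for _ in range(num_layers)
)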
@@ -147,15 +147,18 @@ class GiniPath(nn.Module):
        logits = [g.node[n]['h'] for n in g.nodes()]
        logits = torch.cat(logits, dim=0)
        return logits

+    # Train on graph g with the given features and target labels. Accepts a loss class and an optimizer class that implements optimizer.step().
    def train(self, g, features, labels, epochs, loss_f=torch.nn.NLLLoss, loss_params={}, optimizer=torch.optim.Adam, optimizer_parameters=None, lr=0.001, ignore=[0], quiet=False):
        labels = torch.LongTensor(labels)
        print(labels)
        _, labels = torch.max(labels, dim=1)
        # convert labels and masks to tensors
        if optimizer_parameters is None:
            optimizer_parameters = self.parameters()
+        # instantiate the optimizer on the given parameters
        optimizer_f = optimizer(optimizer_parameters, lr)
        for epoch in range(epochs):
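Since train() takes the loss and optimizer as classes rather than instances, a call looks like the following hedged example; `weights` is a made-up tensor and num_classes an assumed variable:

# hypothetical call: pass classes, train() instantiates them internally
weights = torch.ones(num_classes)   # assumed per-class loss weights
model.train(g, features, y_train, epochs=10,
            loss_f=torch.nn.NLLLoss, loss_params={'weight': weights},
            optimizer=torch.optim.Adam, lr=0.001)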
@@ -169,7 +172,10 @@ class GiniPath(nn.Module):
            # forward
            logits = self.forward(g)

+            # instantiate the loss on the passed parameters (e.g. class-weight params)
            loss = loss_f(**loss_params)

+            # trim ignored ("null") labels
            idx = [i for i, a in enumerate(labels) if a not in ignore]
            logits = logits[idx, :]
            labels = labels[idx]
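The label-trimming step above can equivalently be written with a boolean mask, avoiding the Python-level index list; a sketch with the same behavior, assuming `ignore` holds label ids as in the signature:

# equivalent masking with tensor ops
keep = torch.ones_like(labels, dtype=torch.bool)
for lbl in ignore:
    keep &= labels != lbl              # drop rows whose label is in `ignore`
logits = logits[keep, :]
labels = labels[keep]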
@@ -183,8 +189,8 @@ class GiniPath(nn.Module):
def main(args):
    # dropout parameters
-    input_dropout = 0.2
-    attention_dropout = 0.2
+    input_dropout = args.idrop
+    attention_dropout = args.adrop

    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
@@ -194,7 +200,7 @@ def main(args):
    g = DGLGraph(adj)

    # create model
-    model = GiniPath(args.num_layers,
+    model = GeniePath(args.num_layers,
                      features.shape[1],
                      args.num_hidden,
                      y_train.shape[1],
@@ -203,7 +209,7 @@ def main(args):
                      input_dropout,
                      attention_dropout,
                      args.residual)
-    model.train(g, features, y_train, epochs=10)
+    model.train(g, features, y_train, epochs=args.epochs)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
@@ -221,6 +227,11 @@ if __name__ == '__main__':
help="use residual connection")
parser.add_argument("--lr", type=float, default=0.001,
help="learning rate")
parser.add_argument("--idrop", type=float, default=0.2,
help="Input dropout")
parser.add_argument("--adrop", type=float, default=0.2,
help="attention dropout")
args = parser.parse_args()
print(args)
...
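With this commit, the dropout rates become configurable from the command line. A hypothetical invocation (the script name is assumed; --idrop and --adrop are the flags added here, --lr was already present):

python geniepath.py --lr 0.001 --idrop 0.2 --adrop 0.2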