Commit 1248bd24 authored by Lingfan Yu

example gcn model

parent 34fac23d
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import sys
# (lingfan): following dataset loading and preprocessing code from tkipf/gcn
# https://github.com/tkipf/gcn/blob/master/gcn/utils.py

def parse_index_file(filename):
    """Parse index file."""
    index = []
    for line in open(filename):
        index.append(int(line.strip()))
    return index

def sample_mask(idx, l):
    """Create mask."""
    mask = np.zeros(l)
    mask[idx] = 1
    return np.array(mask, dtype=bool)
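
# For example, sample_mask([0, 2], 4) returns
# array([ True, False,  True, False]).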

def load_data(dataset_str):
    """
    Loads input data from gcn/data directory

    ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
        (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
    ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
    ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
    ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
        object;
    ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.

    All objects above must be saved using the python pickle module.

    :param dataset_str: Dataset name
    :return: All data input files loaded (as well as the training/test data).
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        with open("data/ind.{}.{}".format(dataset_str, name), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))
    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    train_mask = sample_mask(idx_train, labels.shape[0])
    val_mask = sample_mask(idx_val, labels.shape[0])
    test_mask = sample_mask(idx_test, labels.shape[0])

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[train_mask, :] = labels[train_mask, :]
    y_val[val_mask, :] = labels[val_mask, :]
    y_test[test_mask, :] = labels[test_mask, :]

    return adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask
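
# A minimal usage sketch (assuming the Planetoid files for the 'cora' split
# are present under data/):
#   adj, features, y_train, y_val, y_test, \
#       train_mask, val_mask, test_mask = load_data('cora')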

def preprocess_features(features):
    """Row-normalize the feature matrix."""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    return features
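
# preprocess_features divides each row of X by its row sum (D^-1 X). A quick
# hand-worked example (values made up for illustration):
#   X = sp.csr_matrix([[1., 1.], [2., 0.]])
#   preprocess_features(X).toarray()  ->  [[0.5, 0.5], [1., 0.]]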
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np

class NodeUpdateModule(nn.Module):
    def __init__(self, input_dim, output_dim, act=None, p=None):
        super(NodeUpdateModule, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.act = act
        self.p = p

    def forward(self, node, msgs):
        h = node['h']
        if self.p is not None:
            h = F.dropout(h, p=self.p)
        # aggregate messages; avoid += so the stored node state
        # is not modified in place
        for msg in msgs:
            h = h + msg
        h = self.linear(h)
        if self.act is not None:
            h = self.act(h)
        # (lingfan): Can user directly update node instead of using return statement?
        return {'h': h}
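
# A standalone sketch of one layer update (shapes chosen only for
# illustration): messages are summed into h, then Linear + activation apply.
#   layer = NodeUpdateModule(input_dim=4, output_dim=2, act=F.relu)
#   out = layer({'h': torch.randn(1, 4)}, [torch.randn(1, 4)])
#   out['h'].shape  ->  torch.Size([1, 2])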

class GCN(nn.Module):
    def __init__(self, input_dim, num_hidden, num_classes, num_layers, activation, dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # hidden layers
        last_dim = input_dim
        for _ in range(num_layers):
            self.layers.append(
                NodeUpdateModule(last_dim, num_hidden, act=activation, p=dropout))
            last_dim = num_hidden
        # output layer
        self.layers.append(NodeUpdateModule(num_hidden, num_classes, p=dropout))

    def forward(self, g):
        g.register_message_func(lambda src, dst, edge: src['h'])
        for layer in self.layers:
            g.register_update_func(layer)
            g.update_all()
        logits = [g.node[n]['h'] for n in g.nodes()]
        return torch.cat(logits, dim=0)
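
# Rough intuition for the loop above (based on the early DGL API used here):
# each update_all() call sends src['h'] along every edge and feeds the
# collected messages through the registered NodeUpdateModule at every node,
# i.e. one round of message passing per layer.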

def main(args):
    # load and preprocess dataset
    adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
    features = preprocess_features(features)

    # initialize graph
    g = DGLGraph(adj)

    # create GCN model
    model = GCN(features.shape[1],
                args.num_hidden,
                y_train.shape[1],
                args.num_layers,
                F.relu,
                args.dropout)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # convert labels and masks to tensor
    labels = torch.FloatTensor(y_train)
    mask = torch.FloatTensor(train_mask.astype(np.float32))

    for epoch in range(args.epochs):
        # reset grad
        optimizer.zero_grad()
        # reset graph states
        for n in g.nodes():
            g.node[n]['h'] = torch.FloatTensor(features[n].toarray())
        # forward
        logits = model(g)
        # masked cross entropy loss (note the negative sign on the
        # log-probabilities)
        # TODO: (lingfan) use gather to speed up
        logp = F.log_softmax(logits, 1)
        loss = -torch.mean(logp * labels * mask.view(-1, 1))
        print("epoch {} loss: {}".format(epoch, loss.item()))
        loss.backward()
        optimizer.step()
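
    # One way to realize the gather TODO above (a sketch; assumes integer
    # class targets derived from the one-hot label matrix):
    #   target = labels.max(1)[1]                       # one-hot -> class index
    #   nll = -logp.gather(1, target.view(-1, 1)).view(-1)
    #   loss = torch.sum(nll * mask) / torch.sum(mask)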

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    parser.add_argument("--dataset", type=str, required=True,
                        help="dataset name")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of gcn layers")
    parser.add_argument("--num-hidden", type=int, default=64,
                        help="number of hidden units")
    parser.add_argument("--epochs", type=int, default=10,
                        help="number of training epochs")
    parser.add_argument("--dropout", type=float, default=None,
                        help="dropout probability")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    args = parser.parse_args()
    print(args)
    main(args)
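
# Example invocation (assuming this script is saved as gcn.py alongside
# dataset.py and the data/ directory):
#   python gcn.py --dataset cora --num-layers 1 --dropout 0.5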