Commit 896dc50e authored by Hao Zhang, committed by Lingfan Yu

[Model] fix bugs in pygat and mxgat. (#323)

* mxgat

* pygat

* Update gat_batch.py

* add train

* Update gat_batch.py

* Update gat.py

* Update gat_batch.py
parent 75e2af79
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
@@ -11,31 +12,37 @@ import numpy as np
import time
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(edges):
-    return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
+    return {'ft': edges.src['ft'], 'a2': edges.src['a2']}
class GATReduce(gluon.Block):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
-        self.attn_drop = attn_drop
+        if attn_drop:
+            self.attn_drop = gluon.nn.Dropout(attn_drop)
+        else:
+            self.attn_drop = 0
def forward(self, nodes):
a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
-        a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
-        ft = nodes.mailbox['ft'] # shape (B, deg, D)
+        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
+        ft = nodes.mailbox['ft']  # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a))
if self.attn_drop != 0.0:
-            e = mx.nd.Dropout(e, self.attn_drop)
-        return {'accum' : mx.nd.sum(e * ft, axis=1)} # shape (B, D)
+            e = self.attn_drop(e)
+        return {'accum': mx.nd.sum(e * ft, axis=1)}  # shape (B, D)
class GATFinalize(gluon.Block):
def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -46,7 +53,7 @@ class GATFinalize(gluon.Block):
self.residual_fc = None
if residual:
if indim != hiddendim:
-                self.residual_fc = gluon.nn.Dense(hiddendim)
+                self.residual_fc = gluon.nn.Dense(hiddendim, use_bias=False)
def forward(self, nodes):
ret = nodes.data['accum']
@@ -57,22 +64,27 @@ class GATFinalize(gluon.Block):
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(gluon.Block):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
self.fc = gluon.nn.Dense(hiddendim)
-        self.drop = drop
-        self.attn_l = gluon.nn.Dense(1)
-        self.attn_r = gluon.nn.Dense(1)
+        if drop:
+            self.drop = gluon.nn.Dropout(drop)
+        else:
+            self.drop = 0
+        self.attn_l = gluon.nn.Dense(1, use_bias=False)
+        self.attn_r = gluon.nn.Dense(1, use_bias=False)
def forward(self, feats):
h = feats
if self.drop != 0.0:
-            h = mx.nd.Dropout(h, self.drop)
+            h = self.drop(h)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
-        return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
+        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}
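GATPrepare relies on the identity behind the paper's attention logit: a^T [Wh_i || Wh_j] = a_l^T (Wh_i) + a_r^T (Wh_j). Each node therefore precomputes its two halves once ('a1' from attn_l, 'a2' from attn_r), and every edge only adds two scalars instead of concatenating feature vectors. A quick numerical check of the identity (arbitrary sizes):

import numpy as np

a_l, a_r = np.random.randn(8), np.random.randn(8)
wh_i, wh_j = np.random.randn(8), np.random.randn(8)
lhs = np.concatenate([a_l, a_r]) @ np.concatenate([wh_i, wh_j])
rhs = a_l @ wh_i + a_r @ wh_j
assert np.allclose(lhs, rhs)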
class GAT(gluon.Block):
def __init__(self,
@@ -134,27 +146,42 @@ class GAT(gluon.Block):
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
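For readers new to this DGL API: update_all runs the message function once per edge, collects the results in each destination node's mailbox, applies the reduce function per node, then the finalize (apply-node) function; pop_n_repr returns a node field and removes it from the graph. A plain-Python sketch of that contract (a toy stand-in, not the DGL implementation):

def update_all(edges, message, reduce_, data):
    mailbox = {}
    for u, v in edges:                       # message: runs per edge
        mailbox.setdefault(v, []).append(message(data[u]))
    for v, msgs in mailbox.items():          # reduce: runs per node
        data[v]['accum'] = reduce_(msgs)

data = {0: {'ft': 1.0}, 1: {'ft': 2.0}, 2: {'ft': 0.0}}
update_all([(0, 2), (1, 2)],
           message=lambda src: src['ft'],
           reduce_=lambda msgs: sum(msgs) / len(msgs),
           data=data)
print(data[2]['accum'])  # 1.5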
+def evaluate(model, features, labels, mask):
+    logits = model(features)
+    logits = logits[mask].asnumpy().squeeze()
+    val_labels = labels[mask].asnumpy().squeeze()
+    max_index = np.argmax(logits, axis=1)
+    accuracy = np.sum(np.where(max_index == val_labels, 1, 0)) / len(val_labels)
+    return accuracy
def main(args):
# load and preprocess dataset
data = load_data(args)
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
-    mask = mx.nd.array(data.train_mask)
+    mask = mx.nd.array(np.where(data.train_mask == 1))
+    test_mask = mx.nd.array(np.where(data.test_mask == 1))
+    val_mask = mx.nd.array(np.where(data.val_mask == 1))
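Unlike the PyTorch version, which keeps 0/1 ByteTensor masks, the masks here are converted to integer indices: np.where(mask == 1) returns a one-element tuple of positions, so the resulting NDArray carries an extra leading dimension, which appears to be why the gathered logits and labels get .squeeze()'d in evaluate and in the training loss. For example:

import numpy as np

train_mask = np.array([1, 0, 0, 1, 1])
idx = np.where(train_mask == 1)  # (array([0, 3, 4]),)
# mx.nd.array(idx) then has shape (1, 3); indexing with it keeps
# that extra dimension around, hence the squeeze() calls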
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
-        ctx = mx.cpu(0)
+        ctx = mx.cpu()
else:
ctx = mx.gpu(args.gpu)
features = features.as_in_context(ctx)
labels = labels.as_in_context(ctx)
mask = mask.as_in_context(ctx)
-    # create GCN model
+    test_mask = test_mask.as_in_context(ctx)
+    val_mask = val_mask.as_in_context(ctx)
+    # create graph
g = DGLGraph(data.graph)
+    # add self-loop
+    g.add_edges(g.nodes(), g.nodes())
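The added self-loops are not cosmetic: the softmax in GATReduce normalizes over incoming messages only, so without an edge from each node to itself a node's own transformed feature 'ft' would never enter its attention-weighted sum. A toy check, assuming the same DGL 0.x graph API used in this file:

from dgl import DGLGraph

g = DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])        # two original edges
g.add_edges(g.nodes(), g.nodes())  # plus one self-loop per node
assert g.number_of_edges() == 5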
+    # create model
model = GAT(g,
@@ -173,7 +200,6 @@ def main(args):
# use optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})
# initialize graph
dur = []
for epoch in range(args.epochs):
if epoch >= 3:
@@ -181,26 +207,30 @@
# forward
with mx.autograd.record():
logits = model(features)
-            loss = mx.nd.softmax_cross_entropy(logits, labels)
-        #optimizer.zero_grad()
-        loss.backward()
-        trainer.step(features.shape[0])
-        loss.wait_to_read()
+            loss = mx.nd.softmax_cross_entropy(logits[mask].squeeze(), labels[mask].squeeze())
+        loss.backward()
+        trainer.step(mask.shape[0])
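If I read the operator's semantics right, mx.nd.softmax_cross_entropy returns the loss summed over the rows it is given (not averaged), so after the change the printed loss scales with the number of labeled training nodes rather than with the whole graph. A small sketch:

import mxnet as mx

logits = mx.nd.array([[2.0, 0.5], [0.2, 1.8]])
labels = mx.nd.array([0, 1])
loss = mx.nd.softmax_cross_entropy(logits, labels)  # summed over rows
print(loss.asscalar() / labels.shape[0])            # per-node average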
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
if epoch % 100 == 0:
val_accuracy = evaluate(model, features, labels, val_mask)
print("Validation Accuracy {:.4f}".format(val_accuracy))
test_accuracy = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(test_accuracy))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=20,
parser.add_argument("--epochs", type=int, default=1000,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=3,
parser.add_argument("--num-heads", type=int, default=8,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
......
@@ -2,7 +2,6 @@
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
-GAT with batch processing
"""
@@ -12,28 +11,33 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gat_message(edges):
-    return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
+    return {'ft': edges.src['ft'], 'a2': edges.src['a2']}
class GATReduce(nn.Module):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
-        self.attn_drop = attn_drop
+        if attn_drop:
+            self.attn_drop = nn.Dropout(p=attn_drop)
+        else:
+            self.attn_drop = 0
def forward(self, nodes):
a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
-        a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
-        ft = nodes.mailbox['ft'] # shape (B, deg, D)
+        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
+        ft = nodes.mailbox['ft']  # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1)
-        if self.attn_drop != 0.0:
-            e = F.dropout(e, self.attn_drop)
-        return {'accum' : torch.sum(e * ft, dim=1)} # shape (B, D)
+        if self.attn_drop:
+            e = self.attn_drop(e)
+        return {'accum': torch.sum(e * ft, dim=1)}  # shape (B, D)
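One practical reason for swapping F.dropout for an nn.Dropout module: the functional form's training flag defaults to True and is not tied to the module's mode, so the old code kept dropping attention weights even during evaluation, whereas the module is switched off by the model.eval() call added below. A quick demonstration:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.ones(1000)
drop = nn.Dropout(p=0.5).eval()  # what model.eval() does to submodules
print((drop(x) == 0).float().mean())            # 0.0: disabled in eval mode
print((F.dropout(x, 0.5) == 0).float().mean())  # ~0.5: still dropping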
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -44,7 +48,8 @@ class GATFinalize(nn.Module):
self.residual_fc = None
if residual:
if indim != hiddendim:
-                self.residual_fc = nn.Linear(indim, hiddendim)
+                self.residual_fc = nn.Linear(indim, hiddendim, bias=False)
+                nn.init.xavier_normal_(self.residual_fc.weight.data, gain=1.414)
def forward(self, nodes):
ret = nodes.data['accum']
@@ -53,24 +58,32 @@ class GATFinalize(nn.Module):
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = nodes.data['h'] + ret
-        return {'head%d' % self.headid : self.activation(ret)}
+        return {'head%d' % self.headid: self.activation(ret)}
class GATPrepare(nn.Module):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
-        self.fc = nn.Linear(indim, hiddendim)
-        self.drop = drop
-        self.attn_l = nn.Linear(hiddendim, 1)
-        self.attn_r = nn.Linear(hiddendim, 1)
+        self.fc = nn.Linear(indim, hiddendim, bias=False)
+        if drop:
+            self.drop = nn.Dropout(drop)
+        else:
+            self.drop = 0
+        self.attn_l = nn.Linear(hiddendim, 1, bias=False)
+        self.attn_r = nn.Linear(hiddendim, 1, bias=False)
+        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
+        nn.init.xavier_normal_(self.attn_l.weight.data, gain=1.414)
+        nn.init.xavier_normal_(self.attn_r.weight.data, gain=1.414)
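The gain of 1.414 is sqrt(2), the standard Xavier gain for ReLU-family activations; xavier_normal_ draws from N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out)). The same value can be derived rather than hard-coded, e.g.:

import torch.nn as nn

fc = nn.Linear(64, 8, bias=False)
gain = nn.init.calculate_gain('relu')  # sqrt(2) ~ 1.414
nn.init.xavier_normal_(fc.weight, gain=gain)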
def forward(self, feats):
h = feats
-        if self.drop != 0.0:
-            h = F.dropout(h, self.drop)
+        if self.drop:
+            h = self.drop(h)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
-        return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
+        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}
class GAT(nn.Module):
def __init__(self,
@@ -132,6 +145,18 @@ class GAT(nn.Module):
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
+def evaluate(model, features, labels, mask):
+    model.eval()
+    with torch.no_grad():
+        logits = model(features)
+        logits = logits[mask]
+        labels = labels[mask]
+        _, indices = torch.max(logits, dim=1)
+        correct = torch.sum(indices == labels)
+        return correct.item() * 1.0 / len(labels)
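torch.max(logits, dim=1) returns a (values, indices) pair, so indices are the predicted classes and the ratio of matches is the accuracy. An equivalent one-liner on toy data:

import torch

logits = torch.tensor([[0.1, 2.0], [1.5, 0.3]])
labels = torch.tensor([1, 0])
acc = (logits.argmax(dim=1) == labels).float().mean().item()  # 1.0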
def main(args):
# load and preprocess dataset
data = load_data(args)
@@ -139,6 +164,8 @@ def main(args):
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
+    test_mask = torch.ByteTensor(data.test_mask)
+    val_mask = torch.ByteTensor(data.val_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
@@ -151,10 +178,12 @@ def main(args):
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
+        val_mask = val_mask.cuda()
-    # create GCN model
+    # create DGL graph
g = DGLGraph(data.graph)
+    # add self loop
+    g.add_edges(g.nodes(), g.nodes())
+    # create model
model = GAT(g,
args.num_layers,
@@ -166,22 +195,23 @@ def main(args):
args.in_drop,
args.attn_drop,
args.residual)
if cuda:
model.cuda()
# use optimizer
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# initialize graph
dur = []
+    begin_time = time.time()
for epoch in range(args.epochs):
+        model.train()
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
-        loss = F.nll_loss(logp, labels)
+        loss = F.nll_loss(logp[mask], labels[mask])
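This is the transductive GAT training setup: the forward pass always runs on the full graph, but the loss is now taken only on the labeled training nodes, selected by ByteTensor masking (0/1 masks index like booleans in this PyTorch era). A minimal sketch:

import torch
import torch.nn.functional as F

logp = F.log_softmax(torch.randn(4, 3), dim=1)
labels = torch.tensor([0, 2, 1, 1])
mask = torch.tensor([1, 0, 0, 1], dtype=torch.uint8)
loss = F.nll_loss(logp[mask], labels[mask])  # rows 0 and 3 only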
optimizer.zero_grad()
loss.backward()
@@ -189,31 +219,42 @@ def main(args):
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
+        if epoch % 100 == 0:
+            acc = evaluate(model, features, labels, val_mask)
+            print("Validation Accuracy {:.4f}".format(acc))
+    end_time = time.time()
+    print((end_time-begin_time)/args.epochs)
+    acc = evaluate(model, features, labels, test_mask)
+    print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=3,
help="number of attentional heads to use")
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=10000,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=8,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
help="size of hidden units")
parser.add_argument("--residual", action="store_false",
help="use residual connection")
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
help="learning rate")
parser.add_argument('--weight_decay', type=float, default=5e-4)
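The new --weight_decay flag is plain L2 regularization folded into Adam: each step adds weight_decay * param to the gradient, so weights shrink even where the loss gradient is zero. A tiny check:

import torch

w = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.Adam([w], lr=0.005, weight_decay=5e-4)
w.grad = torch.zeros(1)
opt.step()
print(w.item() < 1.0)  # True: decay alone moved the weight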
args = parser.parse_args()
print(args)
......