"test/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "aa264980d38d7d957f3bf6835fbcae260ba1dfea"
Commit 896dc50e authored by Hao Zhang, committed by Lingfan Yu

[Model] fix bugs in pygat and mxgat. (#323)

* mxgat

* pygat

* Update gat_batch.py

* add train

* Update gat_batch.py

* Update gat.py

* Update gat_batch.py
parent 75e2af79
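For reference, the attention that both example scripts below implement (the MXNet one first, then the PyTorch one) can be written in the code's own variable names as

    e_ij     = LeakyReLU(a1_i + a2_j)            where a1 = attn_l(ft), a2 = attn_r(ft), ft = fc(h)
    alpha_ij = softmax(e_ij) over the neighbors j of node i
    accum_i  = sum over j of alpha_ij * ft_j     (computed in GATReduce)

and each attention head then applies its activation, plus an optional residual connection, to accum_i in GATFinalize.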
""" """
Graph Attention Networks Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
...@@ -11,31 +12,37 @@ import numpy as np ...@@ -11,31 +12,37 @@ import numpy as np
import time import time
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
import dgl
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def elu(data): def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu') return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(edges): def gat_message(edges):
return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']} return {'ft': edges.src['ft'], 'a2': edges.src['a2']}
class GATReduce(gluon.Block): class GATReduce(gluon.Block):
def __init__(self, attn_drop): def __init__(self, attn_drop):
super(GATReduce, self).__init__() super(GATReduce, self).__init__()
self.attn_drop = attn_drop if attn_drop:
self.attn_drop = gluon.nn.Dropout(attn_drop)
else:
self.attn_drop = 0
def forward(self, nodes): def forward(self, nodes):
a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1) a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1) a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D) ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention # attention
a = a1 + a2 # shape (B, deg, 1) a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a)) e = mx.nd.softmax(mx.nd.LeakyReLU(a))
if self.attn_drop != 0.0: if self.attn_drop != 0.0:
e = mx.nd.Dropout(e, self.attn_drop) e = self.attn_drop(e)
return {'accum' : mx.nd.sum(e * ft, axis=1)} # shape (B, D) return {'accum': mx.nd.sum(e * ft, axis=1)} # shape (B, D)
class GATFinalize(gluon.Block): class GATFinalize(gluon.Block):
def __init__(self, headid, indim, hiddendim, activation, residual): def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -46,7 +53,7 @@ class GATFinalize(gluon.Block):
         self.residual_fc = None
         if residual:
             if indim != hiddendim:
-                self.residual_fc = gluon.nn.Dense(hiddendim)
+                self.residual_fc = gluon.nn.Dense(hiddendim, use_bias=False)

     def forward(self, nodes):
         ret = nodes.data['accum']
@@ -57,22 +64,27 @@ class GATFinalize(gluon.Block):
             ret = nodes.data['h'] + ret
         return {'head%d' % self.headid : self.activation(ret)}

 class GATPrepare(gluon.Block):
     def __init__(self, indim, hiddendim, drop):
         super(GATPrepare, self).__init__()
         self.fc = gluon.nn.Dense(hiddendim)
-        self.drop = drop
-        self.attn_l = gluon.nn.Dense(1)
-        self.attn_r = gluon.nn.Dense(1)
+        if drop:
+            self.drop = gluon.nn.Dropout(drop)
+        else:
+            self.drop = 0
+        self.attn_l = gluon.nn.Dense(1, use_bias=False)
+        self.attn_r = gluon.nn.Dense(1, use_bias=False)

     def forward(self, feats):
         h = feats
         if self.drop != 0.0:
-            h = mx.nd.Dropout(h, self.drop)
+            h = self.drop(h)
         ft = self.fc(h)
         a1 = self.attn_l(ft)
         a2 = self.attn_r(ft)
-        return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
+        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}

 class GAT(gluon.Block):
     def __init__(self,
@@ -134,27 +146,42 @@ class GAT(gluon.Block):
         self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
         return self.g.pop_n_repr('head0')

+def evaluate(model, features, labels, mask):
+    logits = model(features)
+    logits = logits[mask].asnumpy().squeeze()
+    val_labels = labels[mask].asnumpy().squeeze()
+    max_index = np.argmax(logits, axis=1)
+    accuracy = np.sum(np.where(max_index == val_labels, 1, 0)) / len(val_labels)
+    return accuracy

 def main(args):
     # load and preprocess dataset
     data = load_data(args)
     features = mx.nd.array(data.features)
     labels = mx.nd.array(data.labels)
-    mask = mx.nd.array(data.train_mask)
+    mask = mx.nd.array(np.where(data.train_mask == 1))
+    test_mask = mx.nd.array(np.where(data.test_mask == 1))
+    val_mask = mx.nd.array(np.where(data.val_mask == 1))
     in_feats = features.shape[1]
     n_classes = data.num_labels
     n_edges = data.graph.number_of_edges()

     if args.gpu < 0:
-        ctx = mx.cpu(0)
+        ctx = mx.cpu()
     else:
         ctx = mx.gpu(args.gpu)

     features = features.as_in_context(ctx)
     labels = labels.as_in_context(ctx)
     mask = mask.as_in_context(ctx)
+    test_mask = test_mask.as_in_context(ctx)
+    val_mask = val_mask.as_in_context(ctx)

-    # create GCN model
+    # create graph
     g = DGLGraph(data.graph)
+    # add self-loop
+    g.add_edges(g.nodes(), g.nodes())

     # create model
     model = GAT(g,
@@ -173,7 +200,6 @@ def main(args):
     # use optimizer
     trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})

-    # initialize graph
     dur = []
     for epoch in range(args.epochs):
         if epoch >= 3:
@@ -181,26 +207,30 @@
         # forward
         with mx.autograd.record():
             logits = model(features)
-            loss = mx.nd.softmax_cross_entropy(logits, labels)
-        #optimizer.zero_grad()
-        loss.backward()
-        trainer.step(features.shape[0])
+            loss = mx.nd.softmax_cross_entropy(logits[mask].squeeze(), labels[mask].squeeze())
+            loss.backward()
+        trainer.step(mask.shape[0])
+        loss.wait_to_read()

         if epoch >= 3:
             dur.append(time.time() - t0)

         print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
             epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
+        if epoch % 100 == 0:
+            val_accuracy = evaluate(model, features, labels, val_mask)
+            print("Validation Accuracy {:.4f}".format(val_accuracy))

+    test_accuracy = evaluate(model, features, labels, test_mask)
+    print("Test Accuracy {:.4f}".format(test_accuracy))

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GAT')
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
             help="Which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=20,
+    parser.add_argument("--epochs", type=int, default=1000,
             help="number of training epochs")
-    parser.add_argument("--num-heads", type=int, default=3,
+    parser.add_argument("--num-heads", type=int, default=8,
             help="number of attentional heads to use")
     parser.add_argument("--num-layers", type=int, default=1,
             help="number of hidden layers")
...
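A note on the MXNet changes above: the dataset's train/validation/test masks are 0/1 vectors (the code itself compares them with == 1), and the fix runs them through np.where(... == 1) so that an NDArray of integer node indices, rather than the raw 0/1 vector, is used to pick out the masked rows for the loss and for the new evaluate helper; the extra leading axis that this indexing produces is what the added .squeeze() calls remove. A minimal sketch of the pattern on toy data (the array values and sizes here are made up for illustration):

import numpy as np
import mxnet as mx

# toy setup: 4 nodes, 3 classes; only the first two nodes are training nodes
logits = mx.nd.random.uniform(shape=(4, 3))
labels = mx.nd.array([1, 0, 2, 0])
train_mask = np.array([1, 1, 0, 0])

# np.where turns the 0/1 vector into the integer indices of the masked nodes
mask = mx.nd.array(np.where(train_mask == 1))   # shape (1, 2)

# indexing with those indices keeps a leading axis of size 1, hence the squeeze()
masked_logits = logits[mask].squeeze()          # shape (2, 3)
masked_labels = labels[mask].squeeze()          # shape (2,)
loss = mx.nd.softmax_cross_entropy(masked_logits, masked_labels)
print(loss.asnumpy())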
@@ -2,7 +2,6 @@
 Graph Attention Networks
 Paper: https://arxiv.org/abs/1710.10903
 Code: https://github.com/PetarV-/GAT
 GAT with batch processing
 """
@@ -12,28 +11,33 @@ import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import dgl
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data

 def gat_message(edges):
-    return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
+    return {'ft': edges.src['ft'], 'a2': edges.src['a2']}

 class GATReduce(nn.Module):
     def __init__(self, attn_drop):
         super(GATReduce, self).__init__()
-        self.attn_drop = attn_drop
+        if attn_drop:
+            self.attn_drop = nn.Dropout(p=attn_drop)
+        else:
+            self.attn_drop = 0

     def forward(self, nodes):
         a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
         a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
         ft = nodes.mailbox['ft'] # shape (B, deg, D)
         # attention
         a = a1 + a2 # shape (B, deg, 1)
         e = F.softmax(F.leaky_relu(a), dim=1)
-        if self.attn_drop != 0.0:
-            e = F.dropout(e, self.attn_drop)
+        if self.attn_drop:
+            e = self.attn_drop(e)
-        return {'accum' : torch.sum(e * ft, dim=1)} # shape (B, D)
+        return {'accum': torch.sum(e * ft, dim=1)} # shape (B, D)

 class GATFinalize(nn.Module):
     def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -44,7 +48,8 @@ class GATFinalize(nn.Module):
         self.residual_fc = None
         if residual:
             if indim != hiddendim:
-                self.residual_fc = nn.Linear(indim, hiddendim)
+                self.residual_fc = nn.Linear(indim, hiddendim, bias=False)
+                nn.init.xavier_normal_(self.residual_fc.weight.data, gain=1.414)

     def forward(self, nodes):
         ret = nodes.data['accum']
@@ -53,24 +58,32 @@ class GATFinalize(nn.Module):
             ret = self.residual_fc(nodes.data['h']) + ret
         else:
             ret = nodes.data['h'] + ret
-        return {'head%d' % self.headid : self.activation(ret)}
+        return {'head%d' % self.headid: self.activation(ret)}

 class GATPrepare(nn.Module):
     def __init__(self, indim, hiddendim, drop):
         super(GATPrepare, self).__init__()
-        self.fc = nn.Linear(indim, hiddendim)
-        self.drop = drop
-        self.attn_l = nn.Linear(hiddendim, 1)
-        self.attn_r = nn.Linear(hiddendim, 1)
+        self.fc = nn.Linear(indim, hiddendim, bias=False)
+        if drop:
+            self.drop = nn.Dropout(drop)
+        else:
+            self.drop = 0
+        self.attn_l = nn.Linear(hiddendim, 1, bias=False)
+        self.attn_r = nn.Linear(hiddendim, 1, bias=False)
+        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
+        nn.init.xavier_normal_(self.attn_l.weight.data, gain=1.414)
+        nn.init.xavier_normal_(self.attn_r.weight.data, gain=1.414)

     def forward(self, feats):
         h = feats
-        if self.drop != 0.0:
-            h = F.dropout(h, self.drop)
+        if self.drop:
+            h = self.drop(h)
         ft = self.fc(h)
         a1 = self.attn_l(ft)
         a2 = self.attn_r(ft)
-        return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
+        return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}

 class GAT(nn.Module):
     def __init__(self,
@@ -132,6 +145,18 @@ class GAT(nn.Module):
         self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
         return self.g.pop_n_repr('head0')

+def evaluate(model, features, labels, mask):
+    model.eval()
+    with torch.no_grad():
+        logits = model(features)
+        logits = logits[mask]
+        labels = labels[mask]
+        _, indices = torch.max(logits, dim=1)
+        correct = torch.sum(indices == labels)
+        return correct.item() * 1.0 / len(labels)

 def main(args):
     # load and preprocess dataset
     data = load_data(args)
@@ -139,6 +164,8 @@ def main(args):
     features = torch.FloatTensor(data.features)
     labels = torch.LongTensor(data.labels)
     mask = torch.ByteTensor(data.train_mask)
+    test_mask = torch.ByteTensor(data.test_mask)
+    val_mask = torch.ByteTensor(data.val_mask)
     in_feats = features.shape[1]
     n_classes = data.num_labels
     n_edges = data.graph.number_of_edges()
@@ -151,10 +178,12 @@ def main(args):
         features = features.cuda()
         labels = labels.cuda()
         mask = mask.cuda()
+        val_mask = val_mask.cuda()

-    # create GCN model
+    # create DGL graph
     g = DGLGraph(data.graph)
+    # add self loop
+    g.add_edges(g.nodes(), g.nodes())

     # create model
     model = GAT(g,
                 args.num_layers,
@@ -166,22 +195,23 @@ def main(args):
                 args.in_drop,
                 args.attn_drop,
                 args.residual)

     if cuda:
         model.cuda()

     # use optimizer
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

     # initialize graph
     dur = []
+    begin_time = time.time()
     for epoch in range(args.epochs):
+        model.train()
         if epoch >= 3:
             t0 = time.time()

         # forward
         logits = model(features)
         logp = F.log_softmax(logits, 1)
-        loss = F.nll_loss(logp, labels)
+        loss = F.nll_loss(logp[mask], labels[mask])

         optimizer.zero_grad()
         loss.backward()
@@ -189,31 +219,42 @@ def main(args):
         if epoch >= 3:
             dur.append(time.time() - t0)

         print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
             epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
+        if epoch % 100 == 0:
+            acc = evaluate(model, features, labels, val_mask)
+            print("Validation Accuracy {:.4f}".format(acc))

+    end_time = time.time()
+    print((end_time-begin_time)/args.epochs)
+    acc = evaluate(model, features, labels, test_mask)
+    print("Test Accuracy {:.4f}".format(acc))

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GAT')
     register_data_args(parser)
     parser.add_argument("--gpu", type=int, default=-1,
             help="Which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=20,
+    parser.add_argument("--epochs", type=int, default=10000,
             help="number of training epochs")
-    parser.add_argument("--num-heads", type=int, default=3,
+    parser.add_argument("--num-heads", type=int, default=8,
             help="number of attentional heads to use")
     parser.add_argument("--num-layers", type=int, default=1,
             help="number of hidden layers")
     parser.add_argument("--num-hidden", type=int, default=8,
             help="size of hidden units")
     parser.add_argument("--residual", action="store_false",
             help="use residual connection")
     parser.add_argument("--in-drop", type=float, default=.6,
             help="input feature dropout")
     parser.add_argument("--attn-drop", type=float, default=.6,
             help="attention dropout")
     parser.add_argument("--lr", type=float, default=0.005,
             help="learning rate")
+    parser.add_argument('--weight_decay', type=float, default=5e-4)
     args = parser.parse_args()
     print(args)
...
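On the PyTorch side the masks stay as ByteTensors, which can index tensors row-wise directly, so the fix applies the mask inside the loss and inside the new evaluate helper instead of training and scoring on every node. A minimal sketch of that masking on toy tensors (the values are made up for illustration):

import torch
import torch.nn.functional as F

# toy setup: 4 nodes, 3 classes; only the first two nodes are training nodes
logits = torch.randn(4, 3)
labels = torch.tensor([1, 0, 2, 0])
mask = torch.ByteTensor([1, 1, 0, 0])   # same dtype the example uses for its masks

# masked loss, as in the updated training loop
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])

# masked accuracy, following the new evaluate helper
_, indices = torch.max(logits[mask], dim=1)
acc = torch.sum(indices == labels[mask]).item() * 1.0 / int(mask.sum())
print(loss.item(), acc)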