Unverified Commit efae0f97 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Model] Improve GAT models (#348)

* two better GAT implementations

* update numbers

* use version switch for spmm

* add missing dropout and output heads
parent 3a868eb0
""" """
Graph Attention Networks Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
......
...@@ -2,15 +2,48 @@ Graph Attention Networks (GAT) ...@@ -2,15 +2,48 @@ Graph Attention Networks (GAT)
============ ============
- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903) - Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo: - Author's code repo (in Tensorflow):
[https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT). [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular pytorch implementation:
[https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
Note that the original code is implemented with Tensorflow for the paper. Requirements
------------
- torch v1.0: the autograd support for sparse mm is only available in v1.0.
- requests
Results ```bash
------- pip install torch==1.0.0 requests
```
How to run
----------
Run with following:
```bash
python train.py --dataset=cora --gpu=0
```
Run with following (available dataset: "cora", "citeseer", "pubmed")
```bash ```bash
python gat.py --dataset cora --gpu 0 --num-heads 8 python train.py --dataset=citeseer --gpu=0
``` ```
```bash
python train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001
```
Results
-------
| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
| ------- | ------------- | ------- | ------------------- | ------------------- |
| Cora | 84.0% | 0.0127 | 0.0982 (**7.7x**) | 0.0424 (**3.3x**) |
| Citeseer | 70.7% | 0.0123 | n/a | n/a |
| Pubmed | 78.1% | 0.0302 | n/a | n/a |
* All the accuracy numbers are obtained after 300 epochs.
* The time measures how long it takes to train one epoch.
* All time is measured on EC2 p3.2xlarge instance w/ V100 GPU.
* Baseline#1: [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
* Baseline#2: [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
""" """
Graph Attention Networks Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
multiple output attention heads.
References
----------
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT Author's code: https://github.com/PetarV-/GAT
GAT with batch processing Pytorch implementation: https://github.com/Diego999/pyGAT
""" """
import argparse import argparse
...@@ -13,77 +20,79 @@ import torch.nn as nn ...@@ -13,77 +20,79 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
import dgl.function as fn
class GraphAttention(nn.Module):
def gat_message(edges): def __init__(self,
return {'ft': edges.src['ft'], 'a2': edges.src['a2']} g,
in_dim,
out_dim,
class GATReduce(nn.Module): num_heads,
def __init__(self, attn_drop): feat_drop,
super(GATReduce, self).__init__() attn_drop,
alpha,
residual=False):
super(GraphAttention, self).__init__()
self.g = g
self.num_heads = num_heads
self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
if feat_drop:
self.feat_drop = nn.Dropout(feat_drop)
else:
self.feat_drop = None
if attn_drop: if attn_drop:
self.attn_drop = nn.Dropout(p=attn_drop) self.attn_drop = nn.Dropout(attn_drop)
else: else:
self.attn_drop = 0 self.attn_drop = None
self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
def forward(self, nodes): self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1) nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1) nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
ft = nodes.mailbox['ft'] # shape (B, deg, D) nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
# attention self.leaky_relu = nn.LeakyReLU(alpha)
a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1)
if self.attn_drop:
e = self.attn_drop(e)
return {'accum': torch.sum(e * ft, dim=1)} # shape (B, D)
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
super(GATFinalize, self).__init__()
self.headid = headid
self.activation = activation
self.residual = residual self.residual = residual
self.residual_fc = None
if residual: if residual:
if indim != hiddendim: if in_dim != out_dim:
self.residual_fc = nn.Linear(indim, hiddendim, bias=False) self.residual_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
nn.init.xavier_normal_(self.residual_fc.weight.data, gain=1.414) nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(nodes.data['h']) + ret
else: else:
ret = nodes.data['h'] + ret self.residual_fc = None
return {'head%d' % self.headid: self.activation(ret)}
class GATPrepare(nn.Module): def forward(self, inputs):
def __init__(self, indim, hiddendim, drop): # prepare
super(GATPrepare, self).__init__() h = inputs
self.fc = nn.Linear(indim, hiddendim, bias=False) if self.feat_drop:
if drop: h = self.feat_drop(h)
self.drop = nn.Dropout(drop) ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
head_ft = ft.transpose(0, 1)
a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)
a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)
if self.feat_drop:
ft = self.feat_drop(ft)
self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
# 1. compute edge attention
self.g.apply_edges(self.edge_attention)
# 2. compute two results, one is the node features scaled by the dropped,
# unnormalized attention values. Another is the normalizer of the attention values.
self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
[fn.sum('ft', 'ft'), fn.sum('a', 'z')])
# 3. apply normalizer
ret = self.g.ndata['ft'] / self.g.ndata['z']
# 4. residual
if self.residual:
if self.residual_fc:
ret = self.residual_fc(h) + ret
else: else:
self.drop = 0 ret = h + ret
self.attn_l = nn.Linear(hiddendim, 1, bias=False) return ret
self.attn_r = nn.Linear(hiddendim, 1, bias=False)
nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
nn.init.xavier_normal_(self.attn_l.weight.data, gain=1.414)
nn.init.xavier_normal_(self.attn_r.weight.data, gain=1.414)
def forward(self, feats):
h = feats
if self.drop:
h = self.drop(h)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}
def edge_attention(self, edges):
# an edge UDF to compute unnormalized attention values from src and dst
a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
a = torch.exp(a).clamp(-10, 10) # use clamp to avoid overflow
if self.attn_drop:
a_drop = self.attn_drop(a)
return {'a' : a, 'a_drop' : a_drop}
class GAT(nn.Module): class GAT(nn.Module):
def __init__(self, def __init__(self,
...@@ -94,57 +103,42 @@ class GAT(nn.Module): ...@@ -94,57 +103,42 @@ class GAT(nn.Module):
num_classes, num_classes,
num_heads, num_heads,
activation, activation,
in_drop, feat_drop,
attn_drop, attn_drop,
alpha,
residual): residual):
super(GAT, self).__init__() super(GAT, self).__init__()
self.g = g self.g = g
self.num_layers = num_layers self.num_layers = num_layers
self.num_heads = num_heads self.gat_layers = nn.ModuleList()
self.prp = nn.ModuleList() self.activation = activation
self.red = nn.ModuleList()
self.fnl = nn.ModuleList()
# input projection (no residual) # input projection (no residual)
for hid in range(num_heads): self.gat_layers.append(GraphAttention(
self.prp.append(GATPrepare(in_dim, num_hidden, in_drop)) g, in_dim, num_hidden, num_heads, feat_drop, attn_drop, alpha, False))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
# hidden layers # hidden layers
for l in range(num_layers - 1): for l in range(num_layers - 1):
for hid in range(num_heads):
# due to multi-head, the in_dim = num_hidden * num_heads # due to multi-head, the in_dim = num_hidden * num_heads
self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop)) self.gat_layers.append(GraphAttention(
self.red.append(GATReduce(attn_drop)) g, num_hidden * num_heads, num_hidden, num_heads,
self.fnl.append(GATFinalize(hid, num_hidden * num_heads, feat_drop, attn_drop, alpha, residual))
num_hidden, activation, residual))
# output projection # output projection
self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop)) self.gat_layers.append(GraphAttention(
self.red.append(GATReduce(attn_drop)) g, num_hidden * num_heads, num_classes, 8,
self.fnl.append(GATFinalize(0, num_hidden * num_heads, feat_drop, attn_drop, alpha, residual))
num_classes, activation, residual))
# sanity check
assert len(self.prp) == self.num_layers * self.num_heads + 1
assert len(self.red) == self.num_layers * self.num_heads + 1
assert len(self.fnl) == self.num_layers * self.num_heads + 1
def forward(self, features): def forward(self, inputs):
last = features h = inputs
for l in range(self.num_layers): for l in range(self.num_layers):
for hid in range(self.num_heads): h = self.gat_layers[l](h).flatten(1)
i = l * self.num_heads + hid h = self.activation(h)
# prepare
self.g.ndata.update(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
last = torch.cat(
[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection # output projection
self.g.ndata.update(self.prp[-1](last)) logits = self.gat_layers[-1](h).sum(1)
self.g.update_all(gat_message, self.red[-1], self.fnl[-1]) return logits
return self.g.pop_n_repr('head0')
def accuracy(logits, labels):
_, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def evaluate(model, features, labels, mask): def evaluate(model, features, labels, mask):
model.eval() model.eval()
...@@ -152,23 +146,29 @@ def evaluate(model, features, labels, mask): ...@@ -152,23 +146,29 @@ def evaluate(model, features, labels, mask):
logits = model(features) logits = model(features)
logits = logits[mask] logits = logits[mask]
labels = labels[mask] labels = labels[mask]
_, indices = torch.max(logits, dim=1) return accuracy(logits, labels)
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
data = load_data(args) data = load_data(args)
features = torch.FloatTensor(data.features) features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels) labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask) train_mask = torch.ByteTensor(data.train_mask)
test_mask = torch.ByteTensor(data.test_mask)
val_mask = torch.ByteTensor(data.val_mask) val_mask = torch.ByteTensor(data.val_mask)
in_feats = features.shape[1] test_mask = torch.ByteTensor(data.test_mask)
num_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_labels
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.sum().item(),
val_mask.sum().item(),
test_mask.sum().item()))
if args.gpu < 0: if args.gpu < 0:
cuda = False cuda = False
...@@ -177,41 +177,44 @@ def main(args): ...@@ -177,41 +177,44 @@ def main(args):
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
features = features.cuda() features = features.cuda()
labels = labels.cuda() labels = labels.cuda()
mask = mask.cuda() train_mask = train_mask.cuda()
val_mask = val_mask.cuda() val_mask = val_mask.cuda()
test_mask = test_mask.cuda()
# create DGL graph # create DGL graph
g = DGLGraph(data.graph) g = DGLGraph(data.graph)
n_edges = g.number_of_edges()
# add self loop # add self loop
g.add_edges(g.nodes(), g.nodes()) g.add_edges(g.nodes(), g.nodes())
# create model # create model
model = GAT(g, model = GAT(g,
args.num_layers, args.num_layers,
in_feats, num_feats,
args.num_hidden, args.num_hidden,
n_classes, n_classes,
args.num_heads, args.num_heads,
F.elu, F.elu,
args.in_drop, args.in_drop,
args.attn_drop, args.attn_drop,
args.alpha,
args.residual) args.residual)
print(model)
if cuda: if cuda:
model.cuda() model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer # use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# initialize graph # initialize graph
dur = [] dur = []
begin_time = time.time()
for epoch in range(args.epochs): for epoch in range(args.epochs):
model.train() model.train()
if epoch >= 3: if epoch >= 3:
t0 = time.time() t0 = time.time()
# forward # forward
logits = model(features) logits = model(features)
logp = F.log_softmax(logits, 1) loss = loss_fcn(logits[train_mask], labels[train_mask])
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
...@@ -219,34 +222,40 @@ def main(args): ...@@ -219,34 +222,40 @@ def main(args):
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if epoch % 100 == 0:
acc = evaluate(model, features, labels, val_mask)
print("Validation Accuracy {:.4f}".format(acc))
train_acc = accuracy(logits[train_mask], labels[train_mask])
if args.fastmode:
val_acc = accuracy(logits[val_mask], labels[val_mask])
else:
val_acc = evaluate(model, features, labels, val_mask)
end_time = time.time() print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
print((end_time-begin_time)/args.epochs) " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
format(epoch, np.mean(dur), loss.item(), train_acc,
val_acc, n_edges / np.mean(dur) / 1000))
print()
acc = evaluate(model, features, labels, test_mask) acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc)) print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT') parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser) register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1, parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.") help="which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=10000, parser.add_argument("--epochs", type=int, default=300,
help="number of training epochs") help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=8, parser.add_argument("--num-heads", type=int, default=8,
help="number of attentional heads to use") help="number of hidden attention heads")
parser.add_argument("--num-out-heads", type=int, default=1,
help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=1, parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers") help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8, parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units") help="number of hidden units")
parser.add_argument("--residual", action="store_false", parser.add_argument("--residual", action="store_true", default=False,
help="use residual connection") help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6, parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout") help="input feature dropout")
...@@ -254,7 +263,12 @@ if __name__ == '__main__': ...@@ -254,7 +263,12 @@ if __name__ == '__main__':
help="attention dropout") help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005, parser.add_argument("--lr", type=float, default=0.005,
help="learning rate") help="learning rate")
parser.add_argument('--weight_decay', type=float, default=5e-4) parser.add_argument('--weight-decay', type=float, default=5e-4,
help="weight decay")
parser.add_argument('--alpha', type=float, default=0.2,
help="the negative slop of leaky relu")
parser.add_argument('--fastmode', action="store_true", default=False,
help="skip re-evaluate the validation set")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -124,8 +124,15 @@ def zeros_like(input): ...@@ -124,8 +124,15 @@ def zeros_like(input):
def ones(shape, dtype, ctx): def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx) return th.ones(shape, dtype=dtype, device=ctx)
def spmm(x, y): if TH_VERSION.version[0] == 0:
# TODO(minjie): note this does not support autograd on the `x` tensor.
# should adopt a workaround using custom op.
def spmm(x, y):
return th.spmm(x, y) return th.spmm(x, y)
else:
# torch v1.0+
def spmm(x, y):
return th.sparse.mm(x, y)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim): def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input) y = th.zeros(n_segs, *input.shape[1:]).to(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment