Unverified Commit efae0f97 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Model] Improve GAT models (#348)

* two better GAT implementations

* update numbers

* use version switch for spmm

* add missing dropout and output heads
parent 3a868eb0
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
......
......@@ -2,15 +2,48 @@ Graph Attention Networks (GAT)
============
- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo:
- Author's code repo (in Tensorflow):
[https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular pytorch implementation:
[https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
Note that the original code is implemented with Tensorflow for the paper.
Requirements
------------
- torch v1.0: the autograd support for sparse mm is only available in v1.0.
- requests
Results
-------
```bash
pip install torch==1.0.0 requests
```
How to run
----------
Run with the following:
```bash
python train.py --dataset=cora --gpu=0
```
Run with the following (available datasets: "cora", "citeseer", "pubmed")
```bash
python gat.py --dataset cora --gpu 0 --num-heads 8
python train.py --dataset=citeseer --gpu=0
```
```bash
python train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001
```
Results
-------
| Dataset | Test Accuracy | Time(s) | Baseline#1 times(s) | Baseline#2 times(s) |
| ------- | ------------- | ------- | ------------------- | ------------------- |
| Cora | 84.0% | 0.0127 | 0.0982 (**7.7x**) | 0.0424 (**3.3x**) |
| Citeseer | 70.7% | 0.0123 | n/a | n/a |
| Pubmed | 78.1% | 0.0302 | n/a | n/a |
* All the accuracy numbers are obtained after 300 epochs.
* The time measures how long it takes to train one epoch.
* All time is measured on EC2 p3.2xlarge instance w/ V100 GPU.
* Baseline#1: [https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
* Baseline#2: [https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
"""
Graph Attention Networks
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
multiple output attention heads.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GAT with batch processing
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import argparse
......@@ -13,77 +20,79 @@ import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
import dgl.function as fn
def gat_message(edges):
    """Edge-wise message UDF: ship each source node's projected features
    ('ft') and its source-side attention term ('a2') along the edge."""
    src_data = edges.src
    return {'ft': src_data['ft'], 'a2': src_data['a2']}
class GraphAttention(nn.Module):
    """One GAT layer with all attention heads batched together.

    Projects inputs to ``num_heads * out_dim``, computes per-head scalar
    attention terms from the learned vectors ``attn_l``/``attn_r``, applies
    an exp-ed LeakyReLU attention normalized on the destination side, and
    optionally adds a residual connection.

    NOTE(review): this region was diff-extraction residue in which the old
    GATPrepare/GATReduce/GATFinalize pipeline was interleaved with this
    class; the reconstruction follows the "+" side of the diff.
    """

    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        """g: the DGLGraph to operate on; in_dim/out_dim: per-head feature
        sizes; feat_drop/attn_drop: dropout rates (0 disables); alpha:
        negative slope of the LeakyReLU; residual: add a skip connection."""
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        # One fused linear projection for all heads.
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        self.feat_drop = nn.Dropout(feat_drop) if feat_drop else None
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop else None
        # Per-head attention vectors, shaped (num_heads, out_dim, 1) so the
        # attention terms can be computed with a single bmm per side.
        self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.residual = residual
        if residual and in_dim != out_dim:
            # Input dimensionality differs from the per-head output, so the
            # skip connection needs its own projection.
            self.residual_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
            nn.init.xavier_normal_(self.residual_fc.weight.data, gain=1.414)
        else:
            self.residual_fc = None

    def forward(self, inputs):
        # Project node features; ft has shape (N, num_heads, out_dim).
        h = inputs
        if self.feat_drop:
            h = self.feat_drop(h)
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
        head_ft = ft.transpose(0, 1)  # (num_heads, N, out_dim)
        # Per-head scalar attention terms, transposed back to (N, num_heads, 1).
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)
        if self.feat_drop:
            # NOTE(review): the original applies feat_drop both to the input
            # and to the projected features — kept as-is.
            ft = self.feat_drop(ft)
        self.g.ndata.update({'ft': ft, 'a1': a1, 'a2': a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute two results: node features scaled by the dropped,
        #    unnormalized attention values, and the normalizer z obtained by
        #    summing the un-dropped attention values.
        self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'),
                           fn.copy_edge('a', 'a')],
                          [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
        # 3. apply normalizer (softmax denominator)
        ret = self.g.ndata['ft'] / self.g.ndata['z']
        # 4. residual
        if self.residual:
            if self.residual_fc is not None:
                # Fix: reshape to (N, num_heads, out_dim) so the skip term
                # matches `ret`; a flat (N, heads*out_dim) tensor would not.
                resval = self.residual_fc(h).reshape(
                    (h.shape[0], self.num_heads, -1))
            else:
                # in_dim == out_dim: broadcast the input over the head axis.
                resval = torch.unsqueeze(h, 1)
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # Edge UDF computing unnormalized attention from src and dst terms.
        a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
        # exp is always positive here; the clamp guards against overflow
        # (the -10 lower bound is inert but kept from the original).
        a = torch.exp(a).clamp(-10, 10)
        if self.attn_drop:
            a_drop = self.attn_drop(a)
        else:
            # Fix: the diff residue left `a_drop` undefined (NameError)
            # whenever attention dropout is disabled.
            a_drop = a
        return {'a': a, 'a_drop': a_drop}
class GAT(nn.Module):
    """Multi-layer GAT: an input GraphAttention layer (no residual),
    ``num_layers - 1`` hidden layers (optional residual), and an output
    layer whose head outputs are summed into class logits.

    NOTE(review): reconstructed from diff residue; the leading __init__
    parameters (g, num_layers, in_dim, num_hidden) are inferred from the
    call site in main() — confirm against upstream.
    """

    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual,
                 num_out_heads=8):
        """num_heads: heads per hidden layer (outputs concatenated);
        num_out_heads: heads of the output layer (outputs summed) — the
        original hard-coded 8 here, kept as the backward-compatible default."""
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, num_heads, feat_drop, attn_drop, alpha, False))
        # hidden layers
        for l in range(num_layers - 1):
            # due to multi-head concatenation, the input size is num_hidden * num_heads
            self.gat_layers.append(GraphAttention(
                g, num_hidden * num_heads, num_hidden, num_heads,
                feat_drop, attn_drop, alpha, residual))
        # output projection
        self.gat_layers.append(GraphAttention(
            g, num_hidden * num_heads, num_classes, num_out_heads,
            feat_drop, attn_drop, alpha, residual))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            # concatenate heads: (N, heads, hidden) -> (N, heads * hidden)
            h = self.gat_layers[l](h).flatten(1)
            h = self.activation(h)
        # output layer: sum over the head axis -> (N, num_classes)
        logits = self.gat_layers[-1](h).sum(1)
        return logits
def accuracy(logits, labels):
    """Fraction of rows whose row-wise argmax of `logits` equals `labels`."""
    _, predicted = torch.max(logits, dim=1)
    n_correct = torch.sum(predicted == labels).item()
    return n_correct * 1.0 / len(labels)
def evaluate(model, features, labels, mask):
    """Accuracy of `model` on the rows of `features` selected by `mask`.

    Puts the model in eval mode and disables autograd for the forward
    pass (no gradients are needed for evaluation).
    """
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        return accuracy(logits, labels)
def main(args):
    """Train a GAT on the dataset selected by `args` and report test accuracy.

    NOTE(review): reconstructed from diff residue (old and new training
    loops were interleaved); follows the "+" side of the diff.
    """
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.sum().item(),
           val_mask.sum().item(),
           test_mask.sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # create DGL graph
    g = DGLGraph(data.graph)
    n_edges = g.number_of_edges()
    # add self loop so every node attends to itself
    g.add_edges(g.nodes(), g.nodes())
    # create model (args.in_drop feeds the model's feature-dropout slot)
    model = GAT(g,
                args.num_layers,
                num_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.alpha,
                args.residual)
    print(model)
    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    dur = []  # per-epoch times; first 3 epochs excluded as warm-up
    for epoch in range(args.epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        train_acc = accuracy(logits[train_mask], labels[train_mask])
        if args.fastmode:
            # reuse the training forward pass instead of re-running the model
            val_acc = accuracy(logits[val_mask], labels[val_mask])
        else:
            val_acc = evaluate(model, features, labels, val_mask)
        # NOTE(review): np.mean(dur) is NaN for the first 3 (warm-up) epochs
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
              " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
              format(epoch, np.mean(dur), loss.item(), train_acc,
                     val_acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
    # Command-line interface; reconstructed from the "+" side of the diff
    # (the old/new add_argument lines were interleaved).
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
                        help="which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=300,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=8,
                        help="number of hidden attention heads")
    parser.add_argument("--num-out-heads", type=int, default=1,
                        help="number of output attention heads")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
                        help="number of hidden units")
    parser.add_argument("--residual", action="store_true", default=False,
                        help="use residual connection")
    parser.add_argument("--in-drop", type=float, default=.6,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        help="weight decay")
    # typo fix: "slop" -> "slope"
    parser.add_argument('--alpha', type=float, default=0.2,
                        help="the negative slope of leaky relu")
    parser.add_argument('--fastmode', action="store_true", default=False,
                        help="skip re-evaluate the validation set")
    args = parser.parse_args()
    print(args)

    # NOTE(review): the call to main(args) appears truncated by the diff
    # fold; restored here so the script actually trains.
    main(args)
......
......@@ -124,8 +124,15 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
    """All-ones tensor with the requested `shape` and `dtype`, placed on
    device `ctx`."""
    out = th.ones(shape, dtype=dtype, device=ctx)
    return out
# Version switch: autograd-capable sparse matmul only exists in torch >= 1.0.
# NOTE(review): the diff residue left a stray duplicate `def spmm(x, y):`
# header above the conditional; this reconstruction keeps only the
# version-switched definitions.
if TH_VERSION.version[0] == 0:
    # TODO(minjie): note this does not support autograd on the `x` tensor.
    # should adopt a workaround using custom op.
    def spmm(x, y):
        """Sparse @ dense matmul (torch 0.x: no autograd through `x`)."""
        return th.spmm(x, y)
else:
    # torch v1.0+
    def spmm(x, y):
        """Sparse @ dense matmul with autograd support on both operands."""
        return th.sparse.mm(x, y)
def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.