"docs/vscode:/vscode.git/clone" did not exist on "dc3e0ca59bf26ebcc9f12ed186bfe8fca86c3a1b"
Unverified commit ee241699 authored by Minjie Wang, committed by GitHub

GAT model (#37)

* GAT model

* fix output projection to have only one head
parent 4673b96f
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import networkx as nx
from dgl.graph import DGLGraph
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataset import load_data, preprocess_features
import numpy as np
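# Per-head GAT attention (see the paper linked above): for each edge (u -> v),
#   e_uv = LeakyReLU(a^T [W h_v || W h_u])
#   alpha_uv = softmax of e_uv over v's in-neighbors
#   h_v' = sum_u alpha_uv * W h_u
# NodeReduceModule evaluates this per destination node for every attention head.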
class NodeReduceModule(nn.Module):
def __init__(self, input_dim, num_hidden, num_heads=3, input_dropout=None,
attention_dropout=None):
super(NodeReduceModule, self).__init__()
self.num_heads = num_heads
self.input_dropout = input_dropout
self.attention_dropout = attention_dropout
self.fc = nn.ModuleList(
[nn.Linear(input_dim, num_hidden, bias=False)
for _ in range(num_heads)])
self.attention = nn.ModuleList(
[nn.Linear(num_hidden * 2, 1, bias=False) for _ in range(num_heads)])
def forward(self, msgs):
src, dst = zip(*msgs)
hu = torch.cat(src, dim=0) # neighbor repr
hv = torch.cat(dst, dim=0) # self (destination) repr
msgs_repr = []
# iterate for each head
for i in range(self.num_heads):
# calc W*hself and W*hneigh
hvv = self.fc[i](hv)
huu = self.fc[i](hu)
# calculate W*hself||W*hneigh
h = torch.cat((hvv, huu), dim=1)
a = F.leaky_relu(self.attention[i](h))
a = F.softmax(a, dim=0)
if self.attention_dropout is not None:
a = F.dropout(a, self.attention_dropout)
if self.input_dropout is not None:
huu = F.dropout(huu, self.input_dropout)
# weighted sum of the projected neighbor features
h = torch.sum(a * huu, 0, keepdim=True)
msgs_repr.append(h)
return msgs_repr
class NodeUpdateModule(nn.Module):
def __init__(self, residual, fc, act, aggregator):
super(NodeUpdateModule, self).__init__()
self.residual = residual
self.fc = fc
self.act = act
self.aggregator = aggregator
def forward(self, node, msgs_repr):
# apply residual connection and activation for each head
for i in range(len(msgs_repr)):
if self.residual:
h = self.fc[i](node['h'])
msgs_repr[i] = msgs_repr[i] + h
if self.act is not None:
msgs_repr[i] = self.act(msgs_repr[i])
# aggregate multi-head results
h = self.aggregator(msgs_repr)
return {'h': h}
class GAT(nn.Module):
def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
activation, input_dropout, attention_dropout, use_residual=False):
super(GAT, self).__init__()
self.input_dropout = input_dropout
self.reduce_layers = nn.ModuleList()
self.update_layers = nn.ModuleList()
# hidden layers
for i in range(num_layers):
if i == 0:
last_dim = in_dim
residual = False
else:
last_dim = num_hidden * num_heads # because of concat heads
residual = use_residual
self.reduce_layers.append(
NodeReduceModule(last_dim, num_hidden, num_heads, input_dropout,
attention_dropout))
self.update_layers.append(
NodeUpdateModule(residual, self.reduce_layers[-1].fc, activation,
lambda x: torch.cat(x, 1)))
# projection
self.reduce_layers.append(
NodeReduceModule(num_hidden * num_heads, num_classes, 1, input_dropout,
attention_dropout))
self.update_layers.append(
NodeUpdateModule(False, self.reduce_layers[-1].fc, None, sum))
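# The output projection has a single attention head, so the `sum` aggregator
# simply unwraps the one-element list returned by the reduce module.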
def forward(self, g):
g.register_message_func(lambda src, dst, edge: (src['h'], dst['h']))
for reduce_func, update_func in zip(self.reduce_layers, self.update_layers):
# apply dropout
if self.input_dropout is not None:
# TODO (lingfan): use batched dropout once we have better api
# for global manipulation
for n in g.nodes():
g.node[n]['h'] = F.dropout(g.node[n]['h'], p=self.input_dropout)
g.register_reduce_func(reduce_func)
g.register_update_func(update_func)
g.update_all()
logits = [g.node[n]['h'] for n in g.nodes()]
logits = torch.cat(logits, dim=0)
return logits
def main(args):
# dropout parameters
input_dropout = 0.2
attention_dropout = 0.2
# load and preprocess dataset
adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask = load_data(args.dataset)
features = preprocess_features(features)
# initialize graph
g = DGLGraph(adj)
# create model
model = GAT(args.num_layers,
features.shape[1],
args.num_hidden,
y_train.shape[1],
args.num_heads,
F.elu,
input_dropout,
attention_dropout,
args.residual)
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# convert labels and masks to tensor
labels = torch.FloatTensor(y_train)
mask = torch.FloatTensor(train_mask.astype(np.float32))
n_train = torch.sum(mask)
for epoch in range(args.epochs):
# reset grad
optimizer.zero_grad()
# reset graph states
for n in g.nodes():
g.node[n]['h'] = torch.FloatTensor(features[n].toarray())
# forward
logits = model.forward(g)
# masked cross entropy loss
# TODO: (lingfan) use gather to speed up
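# y_train is assumed to be a one-hot (N, C) matrix that is all-zero outside the
# training set; multiplying by the mask and dividing by n_train therefore gives
# the mean cross entropy over training nodes.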
logp = F.log_softmax(logits, 1)
loss = -torch.sum(logp * labels * mask.view(-1, 1)) / n_train
print("epoch {} loss: {}".format(epoch, loss.item()))
loss.backward()
optimizer.step()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
parser.add_argument("--dataset", type=str, required=True,
help="dataset name")
parser.add_argument("--epochs", type=int, default=10,
help="training epoch")
parser.add_argument("--num-heads", type=int, default=3,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
parser.add_argument("--residual", action="store_true",
help="use residual connection")
parser.add_argument("--lr", type=float, default=0.001,
help="learning rate")
args = parser.parse_args()
print(args)
main(args)
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
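# Rather than forming a^T [W h_i || W h_j] per edge, the attention logit is split
# into attn_l(W h_i) + attn_r(W h_j): GATPrepare computes both halves once per node
# ('a1' for the destination, 'a2' for the source) and GATReduce adds them per message.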
class GATReduce(nn.Module):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = torch.unsqueeze(node['a1'], 0) # shape (1, 1)
a2 = torch.cat([torch.unsqueeze(m['a2'], 0) for m in msgs], dim=0) # shape (deg, 1)
ft = torch.cat([torch.unsqueeze(m['ft'], 0) for m in msgs], dim=0) # shape (deg, D)
# attention
a = a1 + a2 # shape (deg, 1)
e = F.softmax(F.leaky_relu(a), dim=0)
if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=0) # shape (D,)
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
super(GATFinalize, self).__init__()
self.headid = headid
self.activation = activation
self.residual = residual
self.residual_fc = None
if residual:
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum):
ret = accum
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
else:
ret = node['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(nn.Module):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
self.fc = nn.Linear(indim, hiddendim)
self.drop = drop
self.attn_l = nn.Linear(hiddendim, 1)
self.attn_r = nn.Linear(hiddendim, 1)
def forward(self, feats):
h = feats
if self.drop != 0.0:
h = F.dropout(h, self.drop)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
class GAT(nn.Module):
def __init__(self,
nx_graph,
num_layers,
in_dim,
num_hidden,
num_classes,
num_heads,
activation,
in_drop,
attn_drop,
residual):
super(GAT, self).__init__()
self.g = DGLGraph(nx_graph)
self.num_layers = num_layers # one extra output projection
self.num_heads = num_heads
self.prp = nn.ModuleList()
self.red = nn.ModuleList()
self.fnl = nn.ModuleList()
# input projection (no residual)
for hid in range(num_heads):
self.prp.append(GATPrepare(in_dim, num_hidden, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
# hidden layers
for l in range(num_layers - 1):
for hid in range(num_heads):
# due to multi-head, the in_dim = num_hidden * num_heads
self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(hid, num_hidden * num_heads,
num_hidden, activation, residual))
# output projection
self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(0, num_hidden * num_heads, num_classes, activation, residual))
# sanity check
assert len(self.prp) == self.num_layers * self.num_heads + 1
assert len(self.red) == self.num_layers * self.num_heads + 1
assert len(self.fnl) == self.num_layers * self.num_heads + 1
def forward(self, features, train_nodes):
last = features
for l in range(self.num_layers):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
for n, h in last.items():
self.g.nodes[n].update(self.prp[i](h))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
last = {}
for n in self.g.nodes():
last[n] = torch.cat(
[self.g.nodes[n]['head%d' % hid] for hid in range(self.num_heads)])
# output projection
for n, h in last.items():
self.g.nodes[n].update(self.prp[-1](h))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return torch.cat([torch.unsqueeze(self.g.nodes[n]['head0'], 0) for n in train_nodes])
def main(args):
# load and preprocess dataset
data = load_data(args)
# features of each sample
features = {}
labels = []
train_nodes = []
for n in data.graph.nodes():
features[n] = torch.FloatTensor(data.features[n, :])
if data.train_mask[n] == 1:
train_nodes.append(n)
labels.append(data.labels[n])
labels = torch.LongTensor(labels)
in_feats = data.features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = {k : v.cuda() for k, v in features.items()}
labels = labels.cuda()
# create model
model = GAT(data.graph,
args.num_layers,
in_feats,
args.num_hidden,
n_classes,
args.num_heads,
F.elu,
args.in_drop,
args.attn_drop,
args.residual)
if cuda:
model.cuda()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# initialize graph
dur = []
for epoch in range(args.epochs):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features, train_nodes)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=8,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
parser.add_argument("--residual", action="store_false",
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
args = parser.parse_args()
print(args)
main(args)
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GAT with batch processing
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
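# Batched variant: destination nodes are grouped by in-degree (degree bucketing in
# graph.py below), so the reduce function receives dense (B, deg, *) message tensors
# instead of per-node Python lists of messages.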
class GATReduce(nn.Module):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = torch.unsqueeze(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1)
if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=1) # shape (B, D)
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
super(GATFinalize, self).__init__()
self.headid = headid
self.activation = activation
self.residual = residual
self.residual_fc = None
if residual:
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum):
ret = accum
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
else:
ret = node['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(nn.Module):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
self.fc = nn.Linear(indim, hiddendim)
self.drop = drop
self.attn_l = nn.Linear(hiddendim, 1)
self.attn_r = nn.Linear(hiddendim, 1)
def forward(self, feats):
h = feats
if self.drop != 0.0:
h = F.dropout(h, self.drop)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
class GAT(nn.Module):
def __init__(self,
g,
num_layers,
in_dim,
num_hidden,
num_classes,
num_heads,
activation,
in_drop,
attn_drop,
residual):
super(GAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.num_heads = num_heads
self.prp = nn.ModuleList()
self.red = nn.ModuleList()
self.fnl = nn.ModuleList()
# input projection (no residual)
for hid in range(num_heads):
self.prp.append(GATPrepare(in_dim, num_hidden, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(hid, in_dim, num_hidden, activation, False))
# hidden layers
for l in range(num_layers - 1):
for hid in range(num_heads):
# due to multi-head, the in_dim = num_hidden * num_heads
self.prp.append(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(hid, num_hidden * num_heads,
num_hidden, activation, residual))
# output projection
self.prp.append(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
self.red.append(GATReduce(attn_drop))
self.fnl.append(GATFinalize(0, num_hidden * num_heads,
num_classes, activation, residual))
# sanity check
assert len(self.prp) == self.num_layers * self.num_heads + 1
assert len(self.red) == self.num_layers * self.num_heads + 1
assert len(self.fnl) == self.num_layers * self.num_heads + 1
def forward(self, features):
last = features
for l in range(self.num_layers):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
self.g.set_n_repr(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i], batchable=True)
# merge all the heads
last = torch.cat(
[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection
self.g.set_n_repr(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1], batchable=True)
return self.g.pop_n_repr('head0')
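# set_n_repr/pop_n_repr move whole (N, D) feature tensors in and out of the graph,
# so each layer runs as a few dense tensor ops instead of a per-node Python loop.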
def main(args):
# load and preprocess dataset
data = load_data(args)
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
# create GAT model
g = DGLGraph(data.graph)
if cuda:
g.set_device(dgl.gpu(args.gpu))
# create model
model = GAT(g,
args.num_layers,
in_feats,
args.num_hidden,
n_classes,
args.num_heads,
F.elu,
args.in_drop,
args.attn_drop,
args.residual)
if cuda:
model.cuda()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# initialize graph
dur = []
for epoch in range(args.epochs):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask]) # loss over training nodes only
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.item(), np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=3,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
parser.add_argument("--residual", action="store_false",
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
args = parser.parse_args()
print(args)
main(args)
@@ -572,8 +572,6 @@ class DGLGraph(DiGraph):
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) == 0:
msgs_reduced = None
elif len(msgs_batch) == 1:
msgs_reduced = msgs_batch[0]
else:
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase
@@ -581,17 +579,48 @@ class DGLGraph(DiGraph):
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
v_is_all = is_all(v)
if v_is_all:
f_update = update_func
reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None:
# no message; only do recv.
if is_all(v):
self.set_n_repr(f_update(self.get_n_repr(), None))
else:
self.set_n_repr(f_update(self.get_n_repr(v), None), v)
else:
# Read the node states in the degree-bucketing order.
reordered_ns = self.get_n_repr(reordered_v)
new_ns = f_update(reordered_ns, all_reduced_msgs)
if is_all(v):
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v)
# TODO(minjie): manually convert ids to context.
indices = F.to_context(indices, self.context)
if isinstance(new_ns, dict):
for key, val in new_ns.items():
self._node_frame[key] = F.gather_row(val, indices)
else:
self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, reordered_v)
def _batch_reduce(self, v, reduce_func):
if is_all(v) and len(self._msg_frame) == 0:
# no message has been sent
return None, None
if is_all(v):
v = list(range(self.number_of_nodes()))
# sanity checks
v = utils.convert_to_id_tensor(v)
f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
reduced_msgs = []
for deg, v_bkt in zip(degrees, v_buckets):
if deg == 0:
continue
bkt_len = len(v_bkt)
uu, vv = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
@@ -611,31 +640,22 @@ class DGLGraph(DiGraph):
dst_reprs = self.get_n_repr(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
if len(reduced_msgs) == 0:
# no message has been sent to the specified node
return None, None
# TODO: clear partial messages
self.clear_messages()
# Read the node states in the degree-bucketing order.
reordered_v = F.pack(v_buckets)
reordered_ns = self.get_n_repr(reordered_v)
# Pack all reduced msgs together
if isinstance(reduced_msgs[0], dict):
all_reduced_msgs = {key : F.pack([msg[key] for msg in reduced_msgs]) for key in reduced_msgs[0]}
else:
all_reduced_msgs = F.pack(reduced_msgs)
new_ns = f_update(reordered_ns, all_reduced_msgs)
if v_is_all:
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v)
# TODO(minjie): manually convert ids to context.
indices = F.to_context(indices, self.context)
if isinstance(new_ns, dict):
for key, val in new_ns.items():
self._node_frame[key] = F.gather_row(val, indices)
else:
self._node_frame[__REPR__] = F.gather_row(new_ns, indices)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, reordered_v)
return reordered_v, all_reduced_msgs
def update_by_edge(self,
u, v,
......
@@ -8,7 +8,7 @@ def message_not_called(hu, e_uv):
assert False
return hu
def reduce_not_called(msgs):
def reduce_not_called(h, msgs):
assert False
return 0
@@ -70,18 +70,7 @@ def test_recv_no_pred():
g.register_update_func(update_no_msg)
g.recv(0)
def test_skipped_reduce():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_func)
g.sendto(0, 1)
g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
if __name__ == '__main__':
test_no_msg_update()
test_double_recv()
test_recv_no_pred()
test_skipped_reduce()