Unverified commit ddc2faa5 authored by Minjie Wang, committed by GitHub

[Bugfix] Fix gat residual bug (#355)

* fix gat residual bug

* fix the residual addition and the output heads; add shape annotations

* minor

* fix the output head average

* add requests package to requirements
parent efae0f97
@@ -4,6 +4,14 @@ Graph Convolutional Networks (GCN)
 Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
 Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)
+
+Requirements
+------------
+- requests
+
+```bash
+pip install requests
+```
 Codes
 -----
 The folder contains two implementations of GCN. `gcn.py` uses user-defined
@@ -47,33 +55,33 @@ new information in the concatenations.
 ```
 # Final accuracy 75.34% MLP without GCN
-DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 0
+DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_concat.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 0
 # Final accuracy 86.57% with 10-layer GCN (symmetric normalization)
-DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
+DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_concat.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
 # Final accuracy 84.42% with 10-layer GCN (unnormalized)
-DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_batch.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10
+DGLBACKEND=mxnet python examples/mxnet/gcn/gcn_concat.py --dataset "citeseer" --n-epochs 200 --gpu 1 --n-layers 10
 ```
 ```
 # Final accuracy 40.62% MLP without GCN
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 0
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 0
 # Final accuracy 92.63% with 10-layer GCN (symmetric normalization)
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
 # Final accuracy 86.60% with 10-layer GCN (unnormalized)
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "cora" --n-epochs 200 --gpu 1 --n-layers 10
 ```
 ```
 # Final accuracy 72.97% MLP without GCN
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 0
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 0
 # Final accuracy 88.33% with 10-layer GCN (symmetric normalization)
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10 --normalization 'sym' --self-loop
 # Final accuracy 83.80% with 10-layer GCN (unnormalized)
-DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_batch.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10
+DGLBACKEND=mxnet python3 examples/mxnet/gcn/gcn_concat.py --dataset "pubmed" --n-epochs 200 --gpu 1 --n-layers 10
 ```
@@ -3,7 +3,7 @@ Graph Attention Networks in DGL using SPMV optimization.
 Multiple heads are also batched together for faster training.
 Compared with the original paper, this code does not implement
-multiple output attention heads.
+early stopping.
 
 References
 ----------
@@ -53,37 +53,38 @@ class GraphAttention(nn.Module):
         self.residual = residual
         if residual:
             if in_dim != out_dim:
-                self.residual_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
-                nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
+                self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
+                nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
             else:
-                self.residual_fc = None
+                self.res_fc = None
 
     def forward(self, inputs):
         # prepare
-        h = inputs
+        h = inputs  # NxD
         if self.feat_drop:
             h = self.feat_drop(h)
-        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))
-        head_ft = ft.transpose(0, 1)
-        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)
-        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)
+        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
+        head_ft = ft.transpose(0, 1)  # HxNxD'
+        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
+        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
         if self.feat_drop:
             ft = self.feat_drop(ft)
         self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
         # 1. compute edge attention
         self.g.apply_edges(self.edge_attention)
-        # 2. compute two results, one is the node features scaled by the dropped,
-        # unnormalized attention values. Another is the normalizer of the attention values.
+        # 2. compute two results: one is the node features scaled by the dropped,
+        # unnormalized attention values; another is the normalizer of the attention values.
         self.g.update_all([fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.copy_edge('a', 'a')],
                           [fn.sum('ft', 'ft'), fn.sum('a', 'z')])
         # 3. apply normalizer
-        ret = self.g.ndata['ft'] / self.g.ndata['z']
+        ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
         # 4. residual
         if self.residual:
-            if self.residual_fc:
-                ret = self.residual_fc(h) + ret
+            if self.res_fc is not None:
+                resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
             else:
-                ret = h + ret
+                resval = torch.unsqueeze(h, 1)  # Nx1xD'
+            ret = resval + ret
         return ret
 
     def edge_attention(self, edges):
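
For readers skimming the hunk above, the residual change is a shape fix: the old code added `h` (shape NxD) or `self.residual_fc(h)` (shape Nx(H*D')) directly to the per-head output `ret` (shape NxHxD'), which in general cannot broadcast. The new code reshapes the projected residual to NxHxD', or inserts a head axis with `unsqueeze`, so the addition broadcasts per head. Below is a minimal standalone sketch of that shape arithmetic; it is not part of the patch, and the sizes and the throwaway `res_fc` layer are made up purely for illustration.

```python
import torch

# Hypothetical sizes, chosen only to illustrate broadcasting in the fixed residual path.
N, H, D_in, D_out = 4, 3, 6, 5
ret = torch.randn(N, H, D_out)               # per-head attention output, NxHxD'
h = torch.randn(N, D_in)                     # layer input, NxD
res_fc = torch.nn.Linear(D_in, H * D_out, bias=False)

# in_dim != out_dim: project the input, then split it into one residual slice per head.
resval = res_fc(h).reshape(N, H, D_out)      # NxHxD'
out_proj = resval + ret                      # NxHxD'

# in_dim == out_dim: add a head axis so the same input broadcasts across all heads.
h_same = torch.randn(N, D_out)               # NxD' (pretend in_dim == out_dim here)
out_same = torch.unsqueeze(h_same, 1) + ret  # Nx1xD' + NxHxD' -> NxHxD'

print(out_proj.shape, out_same.shape)        # torch.Size([4, 3, 5]) torch.Size([4, 3, 5])
```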
@@ -101,7 +102,7 @@ class GAT(nn.Module):
                  in_dim,
                  num_hidden,
                  num_classes,
-                 num_heads,
+                 heads,
                  activation,
                  feat_drop,
                  attn_drop,
@@ -114,16 +115,16 @@ class GAT(nn.Module):
         self.activation = activation
         # input projection (no residual)
         self.gat_layers.append(GraphAttention(
-            g, in_dim, num_hidden, num_heads, feat_drop, attn_drop, alpha, False))
+            g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False))
         # hidden layers
-        for l in range(num_layers - 1):
+        for l in range(1, num_layers):
             # due to multi-head, the in_dim = num_hidden * num_heads
             self.gat_layers.append(GraphAttention(
-                g, num_hidden * num_heads, num_hidden, num_heads,
+                g, num_hidden * heads[l-1], num_hidden, heads[l],
                 feat_drop, attn_drop, alpha, residual))
         # output projection
         self.gat_layers.append(GraphAttention(
-            g, num_hidden * num_heads, num_classes, 8,
+            g, num_hidden * heads[-2], num_classes, heads[-1],
             feat_drop, attn_drop, alpha, residual))
 
     def forward(self, inputs):
@@ -132,7 +133,7 @@ class GAT(nn.Module):
             h = self.gat_layers[l](h).flatten(1)
             h = self.activation(h)
         # output projection
-        logits = self.gat_layers[-1](h).sum(1)
+        logits = self.gat_layers[-1](h).mean(1)
         return logits
 
 def accuracy(logits, labels):
@@ -187,12 +188,13 @@ def main(args):
     # add self loop
     g.add_edges(g.nodes(), g.nodes())
     # create model
+    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
     model = GAT(g,
                 args.num_layers,
                 num_feats,
                 args.num_hidden,
                 n_classes,
-                args.num_heads,
+                heads,
                 F.elu,
                 args.in_drop,
                 args.attn_drop,
...
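
Not part of the patch, but it may help to see how the new per-layer `heads` list threads through the constructor above: each hidden GAT layer concatenates its heads, so layer `l` consumes `num_hidden * heads[l-1]` features, while the output layer produces `heads[-1]` heads that are averaged into the final logits (the `.mean(1)` change above). A small sketch with made-up values standing in for the command-line arguments (illustrative, not necessarily the script's defaults):

```python
# Illustrative values standing in for args.num_heads, args.num_out_heads, etc.
num_layers, num_heads, num_out_heads = 2, 8, 1
num_hidden, in_dim, num_classes = 8, 1433, 7

heads = ([num_heads] * num_layers) + [num_out_heads]   # [8, 8, 1]

# Input layer: in_dim -> num_hidden per head, heads[0] heads (outputs concatenated).
print('input layer:', in_dim, '->', num_hidden, 'x', heads[0])
# Hidden layers: the concatenated output of the previous layer is the next layer's input.
for l in range(1, num_layers):
    print('hidden layer', l, ':', num_hidden * heads[l - 1], '->', num_hidden, 'x', heads[l])
# Output layer: heads[-1] heads of size num_classes, averaged into the final logits.
print('output layer:', num_hidden * heads[-2], '->', num_classes, 'x', heads[-1], '(averaged)')
```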
@@ -5,9 +5,17 @@ Graph Convolutional Networks (GCN)
 - Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn). Note that the original code is
   implemented with Tensorflow for the paper.
+
+Requirements
+------------
+- requests
+
+```bash
+pip install requests
+```
 Codes
 -----
-The folder contains two implementations of GCN. `gcn_batch.py` uses user-defined
+The folder contains two implementations of GCN. `gcn.py` uses user-defined
 message and reduce functions. `gcn_spmv.py` uses DGL's builtin functions so
 SPMV optimization could be applied.
...
@@ -625,8 +625,6 @@ class GraphIndex(object):
             x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
             y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
             dat = F.cat([x, y], dim=0)
-            print(idx)
-            print(dat)
             inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
         else:
             raise DGLError('Invalid incidence matrix type: %s' % str(typestr))
...