Commit 74e13eea authored by Lingfan Yu's avatar Lingfan Yu Committed by Minjie Wang

[Model] Update GAT model code (#622)

* fix gat code to use latest edge softmax module

* avoid transpose

* update README

* use edge_softmax op

* mxnet edge softmax op

* mxnet gat

* update README

* fix unittest

* fix ci

* fix mxnet nn test; relax criteria for prod reducer
parent 993fd3f9
@@ -13,16 +13,17 @@ maintaining high computation efficiency.
A summary of the model accuracy and training speed with the Pytorch backend (on Amazon EC2 p3.2x instance (w/ V100 GPU)), as compared with the best open-source implementations:

| Model | Reported <br> Accuracy | DGL <br> Accuracy | Author's training speed (epoch time) | DGL speed (epoch time) | Improvement |
| ----- | ----------------- | ------------ | ------------------------------------ | ---------------------- | ----------- |
| [GCN](https://arxiv.org/abs/1609.02907) | 81.5% | 81.0% | [0.0051s (TF)](https://github.com/tkipf/gcn) | 0.0038s | 1.34x |
+| [GAT](https://arxiv.org/abs/1710.10903) | 83.0% | 83.9% | [0.0982s (TF)](https://github.com/PetarV-/GAT) | 0.0076s | 12.9x |
| [SGC](https://arxiv.org/abs/1902.07153) | 81.0% | 81.9% | n/a | 0.0008s | n/a |
| [TreeLSTM](http://arxiv.org/abs/1503.00075) | 51.0% | 51.72% | [14.02s (DyNet)](https://github.com/clab/dynet/tree/master/examples/treelstm) | 3.18s | 4.3x |
| [R-GCN <br> (classification)](https://arxiv.org/abs/1703.06103) | 73.23% | 73.53% | [0.2853s (Theano)](https://github.com/tkipf/relational-gcn) | 0.0097s | 29.4x |
| [R-GCN <br> (link prediction)](https://arxiv.org/abs/1703.06103) | 0.158 | 0.151 | [2.204s (TF)](https://github.com/MichSchli/RelationPrediction) | 0.453s | 4.86x |
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |

With the MXNet/Gluon backend, we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance,
with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)),
...
@@ -19,5 +19,5 @@ pip install requests
### Usage (make sure that DGLBACKEND is changed into mxnet)
```bash
-DGLBACKEND=mxnet python3 gat_batch.py --dataset cora --gpu 0 --num-heads 8
+DGLBACKEND=mxnet python3 train.py --dataset cora --gpu 0 --num-heads 8
```
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import mxnet as mx
from mxnet import gluon
import mxnet.ndarray as nd
import mxnet.gluon.nn as nn
import dgl.function as fn
from dgl.nn.mxnet import edge_softmax
class GraphAttention(gluon.Block):
def __init__(self,
g,
in_dim,
out_dim,
num_heads,
feat_drop,
attn_drop,
alpha,
residual=False):
super(GraphAttention, self).__init__()
self.g = g
self.num_heads = num_heads
self.fc = nn.Dense(num_heads * out_dim, use_bias=False,
weight_initializer=mx.init.Xavier())
if feat_drop:
self.feat_drop = nn.Dropout(feat_drop)
else:
self.feat_drop = lambda x : x
if attn_drop:
self.attn_drop = nn.Dropout(attn_drop)
else:
self.attn_drop = lambda x : x
self.attn_l = self.params.get("left_att", grad_req="add",
shape=(1, num_heads, out_dim),
init=mx.init.Xavier())
self.attn_r = self.params.get("right_att", grad_req="add",
shape=(1, num_heads, out_dim),
init=mx.init.Xavier())
self.alpha = alpha
self.softmax = edge_softmax
self.residual = residual
if residual:
if in_dim != out_dim:
self.res_fc = nn.Dense(num_heads * out_dim, use_bias=False,
weight_initializer=mx.init.Xavier())
else:
self.res_fc = None
def forward(self, inputs):
# prepare
h = self.feat_drop(inputs) # NxD
ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD'
a1 = (ft * self.attn_l.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1
a2 = (ft * self.attn_r.data(ft.context)).sum(axis=-1).expand_dims(-1) # N x H x 1
self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
# 1. compute edge attention
self.g.apply_edges(self.edge_attention)
# 2. compute softmax
self.edge_softmax()
# 3. compute the aggregated node features
self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'),
fn.sum('ft', 'ft'))
ret = self.g.ndata['ft']
# 4. residual
if self.residual:
if self.res_fc is not None:
resval = self.res_fc(h).reshape(
(h.shape[0], self.num_heads, -1)) # NxHxD'
else:
resval = nd.expand_dims(h, axis=1) # Nx1xD'
ret = resval + ret
return ret
def edge_attention(self, edges):
# an edge UDF to compute unnormalized attention values from src and dst
a = nd.LeakyReLU(edges.src['a1'] + edges.dst['a2'], slope=self.alpha)
return {'a' : a}
def edge_softmax(self):
attention = self.softmax(self.g, self.g.edata.pop('a'))
# Dropout attention scores and save them
self.g.edata['a_drop'] = self.attn_drop(attention)
class GAT(nn.Block):
def __init__(self,
g,
num_layers,
in_dim,
num_hidden,
num_classes,
heads,
activation,
feat_drop,
attn_drop,
alpha,
residual):
super(GAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.gat_layers = []
self.activation = activation
# input projection (no residual)
self.gat_layers.append(GraphAttention(
g, in_dim, num_hidden, heads[0],
feat_drop, attn_drop, alpha, False))
# hidden layers
for l in range(1, num_layers):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gat_layers.append(GraphAttention(
g, num_hidden * heads[l-1], num_hidden, heads[l],
feat_drop, attn_drop, alpha, residual))
# output projection
self.gat_layers.append(GraphAttention(
g, num_hidden * heads[-2], num_classes, heads[-1],
feat_drop, attn_drop, alpha, residual))
for i, layer in enumerate(self.gat_layers):
self.register_child(layer, "gat_layer_{}".format(i))
def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](h).flatten()
h = self.activation(h)
# output projection
logits = self.gat_layers[-1](h).mean(1)
return logits
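
For reference, a minimal usage sketch of the module above (not part of the commit; it assumes the Cora loader in `dgl.data.citation_graph`, that the file is importable as `gat.py` as in the training script below, and that DGL runs with `DGLBACKEND=mxnet`):

```python
# Minimal sketch: one forward pass of the GAT module above on Cora.
import mxnet as mx
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
from gat import GAT

def elu(x):
    return mx.nd.LeakyReLU(x, act_type='elu')

data = citegrh.load_cora()
g = DGLGraph(data.graph)
g.add_edges(g.nodes(), g.nodes())        # add self loops, as train.py does
features = mx.nd.array(data.features)

# one hidden GAT layer with 8 heads, plus a single-head output layer
model = GAT(g, num_layers=1, in_dim=features.shape[1], num_hidden=8,
            num_classes=data.num_labels, heads=[8, 1], activation=elu,
            feat_drop=0.6, attn_drop=0.6, alpha=0.2, residual=False)
model.initialize()
logits = model(features)                 # shape: (num_nodes, num_classes)
```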
""" """
Graph Attention Networks Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
early stopping.
References
----------
Paper: https://arxiv.org/abs/1710.10903 Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
GAT with batch processing
""" """
import argparse import argparse
import numpy as np
import time import time
import mxnet as mx import mxnet as mx
from mxnet import gluon from mxnet import gluon
import numpy as np
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
from gat import GAT
def elu(data): def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu') return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(edges):
return {'ft': edges.src['ft'], 'a2': edges.src['a2']}
class GATReduce(gluon.Block):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
if attn_drop:
self.attn_drop = gluon.nn.Dropout(attn_drop)
else:
self.attn_drop = 0
def forward(self, nodes):
a1 = mx.nd.expand_dims(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = mx.nd.softmax(mx.nd.LeakyReLU(a))
if self.attn_drop != 0.0:
e = self.attn_drop(e)
return {'accum': mx.nd.sum(e * ft, axis=1)} # shape (B, D)
class GATFinalize(gluon.Block):
def __init__(self, headid, indim, hiddendim, activation, residual):
super(GATFinalize, self).__init__()
self.headid = headid
self.activation = activation
self.residual = residual
self.residual_fc = None
if residual:
if indim != hiddendim:
self.residual_fc = gluon.nn.Dense(hiddendim, use_bias=False)
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(gluon.Block):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
self.fc = gluon.nn.Dense(hiddendim)
if drop:
self.drop = gluon.nn.Dropout(drop)
else:
self.drop = 0
self.attn_l = gluon.nn.Dense(1, use_bias=False)
self.attn_r = gluon.nn.Dense(1, use_bias=False)
def forward(self, feats):
h = feats
if self.drop != 0.0:
h = self.drop(h)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
return {'h': h, 'ft': ft, 'a1': a1, 'a2': a2}
class GAT(gluon.Block):
def __init__(self,
g,
num_layers,
in_dim,
num_hidden,
num_classes,
num_heads,
activation,
in_drop,
attn_drop,
residual):
super(GAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.num_heads = num_heads
self.prp = gluon.nn.Sequential()
self.red = gluon.nn.Sequential()
self.fnl = gluon.nn.Sequential()
# input projection (no residual)
for hid in range(num_heads):
self.prp.add(GATPrepare(in_dim, num_hidden, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(hid, in_dim, num_hidden, activation, False))
# hidden layers
for l in range(num_layers - 1):
for hid in range(num_heads):
# due to multi-head, the in_dim = num_hidden * num_heads
self.prp.add(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(hid, num_hidden * num_heads,
num_hidden, activation, residual))
# output projection
self.prp.add(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(0, num_hidden * num_heads,
num_classes, activation, residual))
# sanity check
assert len(self.prp) == self.num_layers * self.num_heads + 1
assert len(self.red) == self.num_layers * self.num_heads + 1
assert len(self.fnl) == self.num_layers * self.num_heads + 1
def forward(self, features):
last = features
for l in range(self.num_layers):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
self.g.set_n_repr(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
last = mx.nd.concat(
*[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection
self.g.set_n_repr(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
def evaluate(model, features, labels, mask):
    logits = model(features)
    logits = logits[mask].asnumpy().squeeze()
@@ -184,15 +62,17 @@ def main(args):
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    # create model
+    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
-               args.num_heads,
+               heads,
                elu,
                args.in_drop,
                args.attn_drop,
+               args.alpha,
                args.residual)
    model.initialize(ctx=ctx)
@@ -224,26 +104,33 @@ def main(args):
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
-                        help="Which GPU to use. Set -1 to use CPU.")
+                        help="which GPU to use. Set -1 to use CPU.")
-    parser.add_argument("--epochs", type=int, default=1000,
+    parser.add_argument("--epochs", type=int, default=200,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=8,
-                        help="number of attentional heads to use")
+                        help="number of hidden attention heads")
+    parser.add_argument("--num-out-heads", type=int, default=1,
+                        help="number of output attention heads")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
-                        help="size of hidden units")
+                        help="number of hidden units")
-    parser.add_argument("--residual", action="store_false",
+    parser.add_argument("--residual", action="store_true", default=False,
                        help="use residual connection")
    parser.add_argument("--in-drop", type=float, default=.6,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
+    parser.add_argument('--weight-decay', type=float, default=5e-4,
+                        help="weight decay")
+    parser.add_argument('--alpha', type=float, default=0.2,
+                        help="the negative slope of leaky relu")
    args = parser.parse_args()
    print(args)
...
@@ -43,9 +43,9 @@ Results
| Dataset | Test Accuracy | Time(s) | Baseline#1 time(s) | Baseline#2 time(s) |
| ------- | ------------- | ------- | ------------------- | ------------------- |
-| Cora | 84.0% | 0.0127 | 0.0982 (**7.7x**) | 0.0424 (**3.3x**) |
+| Cora | 84.0% | 0.0113 | 0.0982 (**8.7x**) | 0.0424 (**3.8x**) |
-| Citeseer | 70.7% | 0.0123 | n/a | n/a |
+| Citeseer | 70.7% | 0.0111 | n/a | n/a |
-| Pubmed | 78.1% | 0.0302 | n/a | n/a |
+| Pubmed | 78.0% | 0.0115 | n/a | n/a |

* All the accuracy numbers are obtained after 300 epochs.
* The time measures how long it takes to train one epoch.
...
@@ -10,7 +10,7 @@ Pytorch implementation: https://github.com/Diego999/pyGAT
import torch
import torch.nn as nn
import dgl.function as fn
-from dgl.nn.pytorch import EdgeSoftmax
+from dgl.nn.pytorch import edge_softmax

class GraphAttention(nn.Module):
    def __init__(self,
@@ -34,13 +34,13 @@ class GraphAttention(nn.Module):
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x : x
-        self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
-        self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
+        self.attn_l = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim)))
+        self.attn_r = nn.Parameter(torch.Tensor(size=(1, num_heads, out_dim)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
-        self.softmax = EdgeSoftmax()
+        self.softmax = edge_softmax
        self.residual = residual
        if residual:
            if in_dim != out_dim:
@@ -53,19 +53,17 @@ class GraphAttention(nn.Module):
        # prepare
        h = self.feat_drop(inputs) # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1)) # NxHxD'
-        head_ft = ft.transpose(0, 1) # HxNxD'
-        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1) # NxHx1
-        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1) # NxHx1
+        a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1) # N x H x 1
+        a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1) # N x H x 1
        self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
-        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
+        # 2. compute softmax
        self.edge_softmax()
-        # 2. compute the aggregated node features scaled by the dropped,
+        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
-        # 3. apply normalizer
-        ret = self.g.ndata['ft'] / self.g.ndata['z'] # NxHxD'
+        ret = self.g.ndata['ft']
        # 4. residual
        if self.residual:
            if self.res_fc is not None:
@@ -81,11 +79,9 @@ class GraphAttention(nn.Module):
        return {'a' : a}

    def edge_softmax(self):
-        scores, normalizer = self.softmax(self.g.edata['a'], self.g)
-        # Save normalizer
-        self.g.ndata['z'] = normalizer
+        attention = self.softmax(self.g, self.g.edata.pop('a'))
        # Dropout attention scores and save them
-        self.g.edata['a_drop'] = self.attn_drop(scores)
+        self.g.edata['a_drop'] = self.attn_drop(attention)

class GAT(nn.Module):
    def __init__(self,
...
"""Package for mxnet-specific NN modules.""" """Package for mxnet-specific NN modules."""
from .conv import * from .conv import *
from .softmax import *
"""Gluon layer for graph related softmax."""
# pylint: disable= no-member, arguments-differ
import mxnet as mx
from ... import utils
from ... import function as fn
__all__ = ['edge_softmax']
class EdgeSoftmax(mx.autograd.Function):
r"""Apply softmax over signals of incoming edges.
For a node :math:`i`, edgesoftmax is an operation of computing
.. math::
a_{ij} = \frac{\exp(z_{ij})}{\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})}
where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
called logits in the context of softmax. :math:`\mathcal{N}(i)` is
the set of nodes that have an edge to :math:`i`.
An example of using edgesoftmax is in
`Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ where
the attention weights are computed with such an edgesoftmax operation.
"""
def __init__(self, g):
super(EdgeSoftmax, self).__init__()
self.g = g
def forward(self, score):
"""
score = dgl.EData(g, score)
score_max = score.dst_max() # of type dgl.NData
score = score - score_max # edge_sub_dst, ret dgl.EData
score_sum = score.dst_sum() # of type dgl.NData
out = score / score_sum # edge_div_dst, ret dgl.EData
return out.data
"""
g = self.g
score_name = utils.get_edata_name(g, 'score')
tmp_name = utils.get_ndata_name(g, 'tmp')
out_name = utils.get_edata_name(g, 'out')
g.edata[score_name] = score
g.update_all(fn.copy_e(score_name, 'm'), fn.max('m', tmp_name))
g.apply_edges(fn.e_sub_v(score_name, tmp_name, out_name))
g.edata[out_name] = g.edata[out_name].exp()
g.update_all(fn.copy_e(out_name, 'm'), fn.sum('m', tmp_name))
g.apply_edges(fn.e_div_v(out_name, tmp_name, out_name))
g.edata.pop(score_name)
g.ndata.pop(tmp_name)
out = g.edata.pop(out_name)
self.save_for_backward(out)
return out
def backward(self, grad_out):
"""
g, out = ctx.backward_cache
grad_out = dgl.EData(g, grad_out)
out = dgl.EData(g, out)
sds = out * grad_out # type dgl.EData
sds_sum = sds.dst_sum() # type dgl.NData
grad_score = sds - sds * sds_sum # multiple expressions
return grad_score.data
"""
g = self.g
out = self.saved_tensors[0]
out_name = utils.get_edata_name(g, 'out')
accum_name = utils.get_ndata_name(g, 'accum')
grad_score_name = utils.get_edata_name(g, 'grad_score')
g.edata[out_name] = out
g.edata[grad_score_name] = out * grad_out
g.update_all(fn.copy_e(grad_score_name, 'm'), fn.sum('m', accum_name))
g.apply_edges(fn.e_mul_v(out_name, accum_name, out_name))
g.ndata.pop(accum_name)
grad_score = g.edata.pop(grad_score_name) - g.edata.pop(out_name)
return grad_score
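
The backward pass above is the Jacobian of a per-destination softmax, reorganized as message passing; a brief check, using the notation of the class docstring with upstream gradient g_{ij} = ∂L/∂a_{ij}:

```latex
\frac{\partial a_{ij}}{\partial z_{ik}} = a_{ij}\,(\delta_{jk} - a_{ik})
\quad\Longrightarrow\quad
\frac{\partial L}{\partial z_{ik}}
  = a_{ik}\, g_{ik} \;-\; a_{ik} \sum_{j\in\mathcal{N}(i)} a_{ij}\, g_{ij}
```

which is exactly `out * grad_out - out * accum`, where `accum` is the per-destination sum of `out * grad_out` produced by the `update_all` / `e_mul_v` calls above.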
def edge_softmax(graph, logits):
r"""Compute edge softmax.
Parameters
----------
graph : DGLGraph
The graph to perform edge softmax
logits : mxnet.nd.NDArray
The input edge feature
Returns
-------
Tensor
Softmax value
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
additional dimensions, :math:`N` is the number of edges.
* Return shape: :math:`(N, *, 1)`
Examples
--------
>>> from dgl.nn.mxnet import edge_softmax
>>> attention = edge_softmax(graph, logits)
"""
softmax_op = EdgeSoftmax(graph)
return softmax_op(logits)
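
As a quick standalone check of the op (a sketch mirroring the unit test added further below; assumes networkx is installed and DGL runs with `DGLBACKEND=mxnet`):

```python
# Sketch: edge softmax over a 3-node path graph with uniform logits.
import mxnet as mx
import networkx as nx
import dgl
from dgl.nn.mxnet import edge_softmax

g = dgl.DGLGraph(nx.path_graph(3))             # directed edges: 0<->1, 1<->2
logits = mx.nd.ones((g.number_of_edges(), 1))  # identical score on every edge
a = edge_softmax(g, logits)
# Node 1 has two incoming edges, so each gets weight 0.5;
# nodes 0 and 2 each have a single incoming edge with weight 1.0.
print(a.asnumpy().flatten())
```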
@@ -2,14 +2,13 @@
# pylint: disable= no-member, arguments-differ
import torch as th
-from ... import backend as F
from ... import utils
from ... import function as fn

-__all__ = ['EdgeSoftmax', 'edge_softmax']
+__all__ = ['edge_softmax']

-class EdgeSoftmax(object):
+class EdgeSoftmax(th.autograd.Function):
    r"""Apply softmax over signals of incoming edges.
    For a node :math:`i`, edgesoftmax is an operation of computing
@@ -26,62 +25,6 @@ class EdgeSoftmax(object):
    the attention weights are computed with such an edgesoftmax operation.
    """
def __call__(self, graph, logits):
r"""Compute edge softmax.
Parameters
----------
graph : DGLGraph
The graph to perform edge softmax
logits : torch.Tensor
The input edge feature
Returns
-------
Unnormalized scores : torch.Tensor
This part gives :math:`\exp(z_{ij})`'s
Normalizer : torch.Tensor
This part gives :math:`\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})`
Notes
-----
* Input shape: :math:`(N, *, 1)` where * means any number of
additional dimensions, :math:`N` is the number of edges.
* Unnormalized scores shape: :math:`(N, *, 1)` where all but the
last dimension are the same shape as the input.
* Normalizer shape: :math:`(M, *, 1)` where :math:`M` is the number
of nodes and all but the first and the last dimensions are the
same as the input.
Note that this computation is still one step away from getting real
softmax results. The last step can be proceeded as follows:
>>> import dgl.function as fn
>>> scores, normalizer = EdgeSoftmax(logits, graph)
>>> graph.edata['a'] = scores
>>> graph.ndata['normalizer'] = normalizer
>>> graph.apply_edges(
lambda edges: {'a': edges.data['a'] / edges.dst['normalizer']})
We left this last step to users as depending on the particular use
case, this step can be combined with other computation at once.
"""
num_nodes = graph.number_of_nodes()
ctx = utils.to_dgl_context(F.context(logits))
gidx = graph._graph.get_immutable_gidx(ctx)
_, dst, _ = graph._graph.edges()
dst = dst.tousertensor(F.context(logits))
empty_map = (None, None)
max_logits_ = F.copy_reduce("max", gidx, fn.TargetCode.EDGE, logits,
num_nodes, empty_map, empty_map)
logits = (logits - max_logits_.index_select(0, dst)).exp()
norm = F.copy_reduce("sum", gidx, fn.TargetCode.EDGE, logits,
num_nodes, empty_map, empty_map)
return logits / norm.index_select(0, dst)
-class EdgeSoftmax1(th.autograd.Function):
-    """EdgeSoftmax implementation with DGL message passing APIs"""
    @staticmethod
    def forward(ctx, g, score):
        """
@@ -131,4 +74,30 @@ class EdgeSoftmax1(th.autograd.Function):
        return None, grad_score

-edge_softmax = EdgeSoftmax1.apply # pylint: disable=invalid-name
+def edge_softmax(graph, logits):
+    r"""Compute edge softmax.
+
+    Parameters
+    ----------
+    graph : DGLGraph
+        The graph to perform edge softmax
+    logits : torch.Tensor
+        The input edge feature
+
+    Returns
+    -------
+    Tensor
+        Softmax value
+
+    Notes
+    -----
+    * Input shape: :math:`(N, *, 1)` where * means any number of
+      additional dimensions, :math:`N` is the number of edges.
+    * Return shape: :math:`(N, *, 1)`
+
+    Examples
+    --------
+    >>> from dgl.nn.pytorch import edge_softmax
+    >>> attention = edge_softmax(graph, logits)
+    """
+    return EdgeSoftmax.apply(graph, logits)
@@ -23,7 +23,7 @@ def array_equal(a, b):
    """Check whether the two tensors are *exactly* equal."""
    pass

-def allclose(a, b):
+def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Check whether the two tensors are numerically close to each other."""
    pass
...
@@ -19,8 +19,8 @@ def is_cuda_available():
def array_equal(a, b):
    return nd.equal(a, b).asnumpy().all()

-def allclose(a, b):
-    return np.allclose(a.asnumpy(), b.asnumpy(), rtol=1e-4, atol=1e-4)
+def allclose(a, b, rtol=1e-4, atol=1e-4):
+    return np.allclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)

def randn(shape):
    return nd.random.randn(*shape)
...
@@ -11,9 +11,9 @@ def is_cuda_available():
def array_equal(a, b):
    return th.equal(a.cpu(), b.cpu())

-def allclose(a, b):
+def allclose(a, b, rtol=1e-4, atol=1e-4):
    return th.allclose(a.float().cpu(),
-                       b.float().cpu(), rtol=1e-4, atol=1e-4)
+                       b.float().cpu(), rtol=rtol, atol=atol)

def randn(shape):
    return th.randn(*shape)
...
@@ -185,17 +185,24 @@ def test_all_binary_builtins():
        print(a)
        print(b)
-        if not F.allclose(r1, r2):
+        if reducer == 'prod':
+            rtol = 1e-2
+            atol = 1e-2
+        else:
+            rtol = 1e-4
+            atol = 1e-4
+        if not F.allclose(r1, r2, rtol, atol):
            _print_error(r1, r2)
-        assert F.allclose(r1, r2)
+        assert F.allclose(r1, r2, rtol, atol)
-        if not F.allclose(rhs_grad_1, rhs_grad_2):
+        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):
            print("left grad")
            _print_error(lhs_grad_1, lhs_grad_2)
-        assert(F.allclose(lhs_grad_1, lhs_grad_2))
+        assert(F.allclose(lhs_grad_1, lhs_grad_2, rtol, atol))
-        if not F.allclose(rhs_grad_1, rhs_grad_2):
+        if not F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol):
            print("right grad")
            _print_error(rhs_grad_1, rhs_grad_2)
-        assert(F.allclose(rhs_grad_1, rhs_grad_2))
+        assert(F.allclose(rhs_grad_1, rhs_grad_2, rtol, atol))

    g = dgl.DGLGraph()
    g.add_nodes(20)
...
@@ -56,5 +56,25 @@ def test_graph_conv():
    h1 = conv(h0, g)
    assert "_gconv_feat" in g.ndata

+def uniform_attention(g, shape):
+    a = mx.nd.ones(shape)
+    target_shape = (g.number_of_edges(),) + (1,) * (len(shape) - 1)
+    return a / g.in_degrees(g.edges()[1]).reshape(target_shape).astype('float32')
+
+def test_edge_softmax():
+    # Basic
+    g = dgl.DGLGraph(nx.path_graph(3))
+    edata = mx.nd.ones((g.number_of_edges(), 1))
+    a = nn.edge_softmax(g, edata)
+    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
+                       1e-4, 1e-4)
+
+    # Test higher dimension case
+    edata = mx.nd.ones((g.number_of_edges(), 3, 1))
+    a = nn.edge_softmax(g, edata)
+    assert np.allclose(a.asnumpy(), uniform_attention(g, a.shape).asnumpy(),
+                       1e-4, 1e-4)

if __name__ == '__main__':
    test_graph_conv()
+    test_edge_softmax()
@@ -52,20 +52,17 @@ def uniform_attention(g, shape):
    return a / g.in_degrees(g.edges()[1]).view(target_shape).float()

def test_edge_softmax():
-    def _test(edge_softmax):
-        # Basic
-        g = dgl.DGLGraph(nx.path_graph(3))
-        edata = th.ones(g.number_of_edges(), 1)
-        a = edge_softmax(g, edata)
-        assert th.allclose(a, uniform_attention(g, a.shape))
-
-        # Test higher dimension case
-        edata = th.ones(g.number_of_edges(), 3, 1)
-        a = edge_softmax(g, edata)
-        assert th.allclose(a, uniform_attention(g, a.shape))
-
-    _test(nn.edge_softmax)
-    _test(nn.EdgeSoftmax())
+    # Basic
+    g = dgl.DGLGraph(nx.path_graph(3))
+    edata = th.ones(g.number_of_edges(), 1)
+    a = nn.edge_softmax(g, edata)
+    assert th.allclose(a, uniform_attention(g, a.shape))
+
+    # Test higher dimension case
+    edata = th.ones(g.number_of_edges(), 3, 1)
+    a = nn.edge_softmax(g, edata)
+    assert th.allclose(a, uniform_attention(g, a.shape))

if __name__ == '__main__':
    test_graph_conv()
...