"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e242de9cb57eded6ee745ba85ee8f3faa63e4567"
Unverified commit 756fdd8e, authored by xnuohz and committed by GitHub

[Example] DeeperGCN (#2831)



* [example] deepergcn

* update

* update

* update

* update

* update
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 1fe08607
@@ -91,8 +91,10 @@ The folder contains example implementations of selected research papers related
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
| [Interaction Networks for Learning about Objects, Relations and Physics](#graphsim) | | | :heavy_check_mark: | | |
| [Representation Learning on Graphs with Jumping Knowledge Networks](#jknet) | :heavy_check_mark: | | | | |
| [DeeperGCN: All You Need to Train Deeper GCNs](#deepergcn) | | | :heavy_check_mark: | | :heavy_check_mark: |
| [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](#dcrnn) | | | :heavy_check_mark: | | |
| [GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](#gaan) | | | :heavy_check_mark: | | |

## 2021
- <a name="bgnn"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL).

@@ -136,17 +138,21 @@ The folder contains example implementations of selected research papers related
    - Tags: molecules, molecular property prediction, quantum chemistry
- <a name="dagnn"></a> Rossi et al. Temporal Graph Networks For Deep Learning on Dynamic Graphs. [Paper link](https://arxiv.org/abs/2006.10637).
    - Example code: [PyTorch](../examples/pytorch/tgn)
    - Tags: over-smoothing, node classification
- <a name="compgcn"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1911.03082).
    - Example code: [PyTorch](../examples/pytorch/compGCN)
    - Tags: multi-relational graphs, graph neural network
- <a name="deepergcn"></a> Li et al. DeeperGCN: All You Need to Train Deeper GCNs. [Paper link](https://arxiv.org/abs/2006.07739).
    - Example code: [PyTorch](../examples/pytorch/deepergcn)
    - Tags: over-smoothing, deeper gnn, OGB

## 2019

@@ -262,11 +268,11 @@ The folder contains example implementations of selected research papers related
    - Tags: graph classification
- <a name="seal"></a> Zhang et al. Link Prediction Based on Graph Neural Networks. [Paper link](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf).
    - Example code: [PyTorch](../examples/pytorch/seal)
    - Tags: link prediction, sampling
- <a name="jknet"></a> Xu et al. Representation Learning on Graphs with Jumping Knowledge Networks. [Paper link](https://arxiv.org/abs/1806.03536).
    - Example code: [PyTorch](../examples/pytorch/jknet)
    - Tags: message passing, neighborhood
- <a name="gaan"></a> Zhang et al. GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs. [Paper link](https://arxiv.org/abs/1803.07294).
# DGL Implementation of DeeperGCN

This DGL example implements the GNN model proposed in the paper [DeeperGCN: All You Need to Train Deeper GCNs](https://arxiv.org/abs/2006.07739). For the original implementation, see [here](https://github.com/lightaime/deep_gcns_torch).

Contributor: [xnuohz](https://github.com/xnuohz)

### Requirements

The codebase is implemented in Python 3.7. For the version requirements of the packages, see below.

```
dgl 0.6.0.post1
torch 1.7.0
ogb 1.3.0
```

### The graph datasets used in this example

Open Graph Benchmark (OGB). Dataset summary:

###### Graph Property Prediction

| Dataset     | #Graphs | #Node Feats | #Edge Feats | Metric  |
| :---------: | :-----: | :---------: | :---------: | :-----: |
| ogbg-molhiv | 41,127  | 9           | 3           | ROC-AUC |

### Usage

Train a model with the original hyperparameters:
```bash
# ogbg-molhiv
python main.py --gpu 0 --learn-beta
```
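
All hyperparameters exposed by `main.py` can be overridden on the command line (the values shown below are the argparse defaults); `--learn-beta` additionally makes the inverse temperature of the softmax aggregator learnable, as in the command above.

```bash
# CPU run with the default hyperparameters spelled out explicitly
python main.py --gpu -1 --epochs 300 --lr 0.01 --dropout 0.2 --batch-size 2048 --num-layers 7 --hid-dim 256
```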

### Performance

The reported metric is ROC-AUC on the ogbg-molhiv test set.

* Table 6: results reported in Table 6 of the paper.
* Author: results obtained by running the original implementation.
* DGL: results obtained by running this DGL example.

| Dataset           | ogbg-molhiv |
| :---------------: | :---------: |
| Results (Table 6) | 0.786       |
| Results (Author)  | 0.781       |
| Results (DGL)     | 0.778       |

### Speed

Average training time per epoch, in seconds (the DGL number is measured by `main.py`, which skips the first 5 epochs).

| Dataset          | ogbg-molhiv |
| :--------------: | :---------: |
| Results (Author) | 11.833      |
| Results (DGL)    | 8.965       |

# layers.py: GENConv layer (generalized message aggregation) used by DeeperGCN.
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn
from dgl.nn.functional import edge_softmax
from ogb.graphproppred.mol_encoder import BondEncoder

from modules import MLP, MessageNorm


class GENConv(nn.Module):
    r"""
    Description
    -----------
    Generalized Message Aggregator, introduced in
    `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_

    Parameters
    ----------
    in_dim: int
        Input size.
    out_dim: int
        Output size.
    aggregator: str
        Type of aggregation. Default is 'softmax'.
    beta: float
        A continuous variable called an inverse temperature. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable variable or not. Default is False.
    p: float
        Initial power for power mean aggregation. Default is 1.0.
    learn_p: bool
        Whether p is a learnable variable or not. Default is False.
    msg_norm: bool
        Whether message normalization is used. Default is False.
    learn_msg_scale: bool
        Whether s is a learnable scaling factor in message normalization. Default is False.
    mlp_layers: int
        The number of MLP layers. Default is 1.
    eps: float
        A small positive constant used in the message construction function. Default is 1e-7.
    """
    def __init__(self,
                 in_dim,
                 out_dim,
                 aggregator='softmax',
                 beta=1.0,
                 learn_beta=False,
                 p=1.0,
                 learn_p=False,
                 msg_norm=False,
                 learn_msg_scale=False,
                 mlp_layers=1,
                 eps=1e-7):
        super(GENConv, self).__init__()

        self.aggr = aggregator
        self.eps = eps

        # MLP channel sizes: in_dim -> (mlp_layers - 1) hidden layers of size 2 * in_dim -> out_dim.
        channels = [in_dim]
        for _ in range(mlp_layers - 1):
            channels.append(in_dim * 2)
        channels.append(out_dim)

        self.mlp = MLP(channels)
        self.msg_norm = MessageNorm(learn_msg_scale) if msg_norm else None

        # beta is only learnable for the softmax aggregator; p controls the power mean aggregator.
        self.beta = nn.Parameter(torch.Tensor([beta]), requires_grad=True) if learn_beta and self.aggr == 'softmax' else beta
        self.p = nn.Parameter(torch.Tensor([p]), requires_grad=True) if learn_p else p

        self.edge_encoder = BondEncoder(in_dim)

    def forward(self, g, node_feats, edge_feats):
        with g.local_scope():
            # Node and edge feature sizes need to match.
            g.ndata['h'] = node_feats
            g.edata['h'] = self.edge_encoder(edge_feats)
            # Message construction: combine source node and edge features.
            g.apply_edges(fn.u_add_e('h', 'h', 'm'))

            if self.aggr == 'softmax':
                # SoftMax aggregation: weight each message by a softmax over incoming edges,
                # scaled by the (possibly learnable) inverse temperature beta.
                g.edata['m'] = F.relu(g.edata['m']) + self.eps
                g.edata['a'] = edge_softmax(g, g.edata['m'] * self.beta)
                g.update_all(lambda edge: {'x': edge.data['m'] * edge.data['a']},
                             fn.sum('x', 'm'))
            elif self.aggr == 'power':
                # Power mean aggregation: (mean(m^p))^(1/p), with clamping for numerical stability.
                minv, maxv = 1e-7, 1e1
                torch.clamp_(g.edata['m'], minv, maxv)
                g.update_all(lambda edge: {'x': torch.pow(edge.data['m'], self.p)},
                             fn.mean('x', 'm'))
                torch.clamp_(g.ndata['m'], minv, maxv)
                g.ndata['m'] = torch.pow(g.ndata['m'], 1 / self.p)
            else:
                raise NotImplementedError(f'Aggregator {self.aggr} is not supported.')

            if self.msg_norm is not None:
                g.ndata['m'] = self.msg_norm(node_feats, g.ndata['m'])

            # Residual connection followed by the update MLP.
            feats = node_feats + g.ndata['m']

            return self.mlp(feats)
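
The layer above can be exercised on its own. The following is a minimal sketch (not part of the original example), assuming a hypothetical 3-node toy graph with zero-valued OGB-style categorical bond features and node features already projected to the hidden size:

```python
# Minimal GENConv smoke test on a toy graph (illustrative sketch; shapes follow ogbg-molhiv conventions).
import torch
import dgl
from layers import GENConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))               # 3-node directed cycle
node_feats = torch.randn(3, 256)                     # node features already at the hidden size
edge_feats = torch.zeros(3, 3, dtype=torch.long)     # 3 categorical bond features per edge, as BondEncoder expects
conv = GENConv(in_dim=256, out_dim=256, aggregator='softmax', learn_beta=True)
out = conv(g, node_feats, edge_feats)
print(out.shape)                                      # expected: torch.Size([3, 256])
```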

# main.py: training script for DeeperGCN on ogbg-molhiv.
import argparse
import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl

from models import DeeperGCN


def train(model, device, data_loader, opt, loss_fn):
    model.train()

    train_loss = []
    for g, labels in data_loader:
        g = g.to(device)
        labels = labels.to(torch.float32).to(device)

        logits = model(g, g.edata['feat'], g.ndata['feat'])
        loss = loss_fn(logits, labels)
        train_loss.append(loss.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

    return sum(train_loss) / len(train_loss)


@torch.no_grad()
def test(model, device, data_loader, evaluator):
    model.eval()

    y_true, y_pred = [], []
    for g, labels in data_loader:
        g = g.to(device)

        logits = model(g, g.edata['feat'], g.ndata['feat'])
        y_true.append(labels.detach().cpu())
        y_pred.append(logits.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    return evaluator.eval({
        'y_true': y_true,
        'y_pred': y_pred
    })['rocauc']


def main():
    # check cuda
    device = f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu'

    # load ogb dataset & evaluator
    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
    evaluator = Evaluator(name='ogbg-molhiv')

    g, _ = dataset[0]
    node_feat_dim = g.ndata['feat'].size()[-1]
    edge_feat_dim = g.edata['feat'].size()[-1]
    n_classes = dataset.num_tasks

    split_idx = dataset.get_idx_split()
    train_loader = DataLoader(dataset[split_idx["train"]],
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=collate_dgl)
    valid_loader = DataLoader(dataset[split_idx["valid"]],
                              batch_size=args.batch_size,
                              shuffle=False,
                              collate_fn=collate_dgl)
    test_loader = DataLoader(dataset[split_idx["test"]],
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=collate_dgl)

    # load model
    model = DeeperGCN(node_feat_dim=node_feat_dim,
                      edge_feat_dim=edge_feat_dim,
                      hid_dim=args.hid_dim,
                      out_dim=n_classes,
                      num_layers=args.num_layers,
                      dropout=args.dropout,
                      learn_beta=args.learn_beta).to(device)
    print(model)

    opt = optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.BCEWithLogitsLoss()

    # training & validation & testing
    best_auc = 0
    best_model = copy.deepcopy(model)
    times = []

    print('---------- Training ----------')
    for i in range(args.epochs):
        t1 = time.time()
        train_loss = train(model, device, train_loader, opt, loss_fn)
        t2 = time.time()
        # Skip the first 5 epochs when measuring the average time per epoch.
        if i >= 5:
            times.append(t2 - t1)

        train_auc = test(model, device, train_loader, evaluator)
        valid_auc = test(model, device, valid_loader, evaluator)

        print(f'Epoch {i} | Train Loss: {train_loss:.4f} | '
              f'Train AUC: {train_auc:.4f} | Valid AUC: {valid_auc:.4f}')

        # Keep the model that performs best on the validation set.
        if valid_auc > best_auc:
            best_auc = valid_auc
            best_model = copy.deepcopy(model)

    print('---------- Testing ----------')
    test_auc = test(best_model, device, test_loader, evaluator)
    print(f'Test AUC: {test_auc}')

    if len(times) > 0:
        print('Times/epoch: ', sum(times) / len(times))


if __name__ == '__main__':
    """
    DeeperGCN Hyperparameters
    """
    parser = argparse.ArgumentParser(description='DeeperGCN')
    # training
    parser.add_argument('--gpu', type=int, default=-1, help='GPU index, -1 for CPU.')
    parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
    parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate.')
    parser.add_argument('--batch-size', type=int, default=2048, help='Batch size.')
    # model
    parser.add_argument('--num-layers', type=int, default=7, help='Number of GNN layers.')
    parser.add_argument('--hid-dim', type=int, default=256, help='Hidden channel size.')
    # learnable parameters in aggr
    parser.add_argument('--learn-beta', action='store_true',
                        help='Make the inverse temperature beta of the softmax aggregator learnable.')
    args = parser.parse_args()
    print(args)

    main()
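
For reference, the data pipeline used above can also be exercised on its own. The sketch below (the same OGB calls as in `main`, with a hypothetical batch size of 32) shows that `collate_dgl` merges each mini-batch into a single `DGLGraph` plus a stacked label tensor:

```python
# Quick look at one training mini-batch (downloads ogbg-molhiv on first use).
from torch.utils.data import DataLoader
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl

dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
loader = DataLoader(dataset[split_idx['train']], batch_size=32, collate_fn=collate_dgl)

g, labels = next(iter(loader))
print(g.batch_size, g.num_nodes(), labels.shape)   # expected: 32 graphs batched into one DGLGraph; labels of shape (32, 1)
```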

# models.py: the DeeperGCN network built from GENConv layers.
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch.glob import AvgPooling
from ogb.graphproppred.mol_encoder import AtomEncoder

from layers import GENConv


class DeeperGCN(nn.Module):
    r"""
    Description
    -----------
    Introduced in `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_

    Parameters
    ----------
    node_feat_dim: int
        Size of node features.
    edge_feat_dim: int
        Size of edge features.
    hid_dim: int
        Size of hidden representations.
    out_dim: int
        Size of output.
    num_layers: int
        Number of graph convolutional layers.
    dropout: float
        Dropout rate. Default is 0.
    beta: float
        A continuous variable called an inverse temperature. Default is 1.0.
    learn_beta: bool
        Whether beta is a learnable weight. Default is False.
    aggr: str
        Type of aggregation. Default is 'softmax'.
    mlp_layers: int
        Number of MLP layers in the GENConv update function. Default is 1.
    """
    def __init__(self,
                 node_feat_dim,
                 edge_feat_dim,
                 hid_dim,
                 out_dim,
                 num_layers,
                 dropout=0.,
                 beta=1.0,
                 learn_beta=False,
                 aggr='softmax',
                 mlp_layers=1):
        super(DeeperGCN, self).__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        self.gcns = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(self.num_layers):
            conv = GENConv(in_dim=hid_dim,
                           out_dim=hid_dim,
                           aggregator=aggr,
                           beta=beta,
                           learn_beta=learn_beta,
                           mlp_layers=mlp_layers)
            self.gcns.append(conv)
            self.norms.append(nn.BatchNorm1d(hid_dim, affine=True))

        self.node_encoder = AtomEncoder(hid_dim)
        self.pooling = AvgPooling()
        self.output = nn.Linear(hid_dim, out_dim)

    def forward(self, g, edge_feats, node_feats=None):
        with g.local_scope():
            hv = self.node_encoder(node_feats)
            he = edge_feats

            # Pre-activation residual blocks: BatchNorm -> ReLU -> Dropout -> GENConv -> Add.
            for layer in range(self.num_layers):
                hv1 = self.norms[layer](hv)
                hv1 = F.relu(hv1)
                hv1 = F.dropout(hv1, p=self.dropout, training=self.training)
                hv = self.gcns[layer](g, hv1, he) + hv

            # Readout: average node representations within each graph.
            h_g = self.pooling(g, hv)

            return self.output(h_g)
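
The same kind of smoke test works at the model level; again, this is an illustrative sketch with hypothetical zero-valued categorical features shaped like ogbg-molhiv inputs:

```python
# End-to-end DeeperGCN smoke test on a toy graph (illustrative sketch).
import torch
import dgl
from models import DeeperGCN

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['feat'] = torch.zeros(3, 9, dtype=torch.long)   # 9 categorical atom features per node
g.edata['feat'] = torch.zeros(3, 3, dtype=torch.long)   # 3 categorical bond features per edge
model = DeeperGCN(node_feat_dim=9, edge_feat_dim=3, hid_dim=256,
                  out_dim=1, num_layers=7, learn_beta=True).eval()
with torch.no_grad():
    logits = model(g, g.edata['feat'], g.ndata['feat'])
print(logits.shape)                                      # expected: torch.Size([1, 1]) - one graph, one task
```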

# modules.py: building blocks (MLP and MessageNorm) shared by the DeeperGCN layers.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Sequential):
    r"""
    Description
    -----------
    From equation (5) in `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_
    """
    def __init__(self,
                 channels,
                 act='relu',
                 dropout=0.,
                 bias=True):
        layers = []

        # Linear -> BatchNorm -> ReLU -> Dropout for every layer except the last, which stays linear.
        for i in range(1, len(channels)):
            layers.append(nn.Linear(channels[i - 1], channels[i], bias))
            if i < len(channels) - 1:
                layers.append(nn.BatchNorm1d(channels[i], affine=True))
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(dropout))

        super(MLP, self).__init__(*layers)


class MessageNorm(nn.Module):
    r"""
    Description
    -----------
    Message normalization was introduced in `DeeperGCN: All You Need to Train Deeper GCNs <https://arxiv.org/abs/2006.07739>`_

    Parameters
    ----------
    learn_scale: bool
        Whether s is a learnable scaling factor or not. Default is False.
    """
    def __init__(self, learn_scale=False):
        super(MessageNorm, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=learn_scale)

    def forward(self, feats, msg, p=2):
        # Normalize the aggregated message, then rescale it by the norm of the node features and s.
        msg = F.normalize(msg, p=p, dim=-1)
        feats_norm = feats.norm(p=p, dim=-1, keepdim=True)
        return msg * feats_norm * self.scale
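
`MessageNorm` is not switched on in this example's default configuration (`msg_norm=False` in `GENConv`), but its behaviour is easy to check in isolation. In the illustrative sketch below, with the scale s fixed at 1, each output row keeps the direction of the message while inheriting the norm of the corresponding node features:

```python
# Illustrative check of MessageNorm: msg is L2-normalized, then rescaled by ||feats||.
import torch
from modules import MessageNorm

norm = MessageNorm(learn_scale=False)
feats = torch.randn(4, 8)
msg = torch.randn(4, 8)
out = norm(feats, msg)
print(torch.allclose(out.norm(dim=-1), feats.norm(dim=-1)))   # expected: True
```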