Unverified Commit 80fdcfdf authored by 张恒瑞, committed by GitHub

[Example] Graph Random Neural Network (#2502)



* Add Implementation of GRAND

* Performance updated

* Update README.md

add indexing entries

* add indexing entries for grand

* change grand from 2021 to 2020

* [Example] fix some bugs

* [Examples] fix bug
Co-authored-by: 张恒瑞 <hengrui@macbook.local>
Co-authored-by: 张恒瑞 <hengrui@192.168.124.35>
Co-authored-by: zhjwy9343 <6593865@qq.com>
parent 1502c56d
@@ -3,7 +3,8 @@
## Overview
| Paper | node classification | link prediction / classification | graph property prediction | sampling | OGB |
| ------------------------------------------------------------ | ------------------- | -------------------------------- | ------------------------- | ------------------ | ------------------ |
| [Graph Random Neural Network for Semi-Supervised Learning on Graphs](#grand) | :heavy_check_mark: | | | | |
| [Heterogeneous Graph Transformer](#hgt) | :heavy_check_mark: | :heavy_check_mark: | | | |
| [Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges](#mwe) | :heavy_check_mark: | | | | :heavy_check_mark: |
| [SIGN: Scalable Inception Graph Neural Networks](#sign) | :heavy_check_mark: | | | | :heavy_check_mark: |
@@ -40,18 +41,18 @@
## 2020
- <a name="grand"></a> Feng et al. Graph Random Neural Network for Semi-Supervised Learning on Graphs. [Paper link](https://arxiv.org/abs/2005.11079).
    - Example code: [PyTorch](../examples/pytorch/grand)
    - Tags: semi-supervised node classification, simplifying graph convolution, data augmentation
- <a name="hgt"></a> Hu et al. Heterogeneous Graph Transformer. [Paper link](https://arxiv.org/abs/2003.01332). - <a name="hgt"></a> Hu et al. Heterogeneous Graph Transformer. [Paper link](https://arxiv.org/abs/2003.01332).
- Example code: [PyTorch](../examples/pytorch/hgt) - Example code: [PyTorch](../examples/pytorch/hgt)
- Tags: dynamic heterogeneous graphs, large-scale, node classification, link prediction - Tags: dynamic heterogeneous graphs, large-scale, node classification, link prediction
- <a name="mwe"></a> Chen. Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges. [Paper link](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf). - <a name="mwe"></a> Chen. Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges. [Paper link](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf).
- Example code: [PyTorch on ogbn-proteins](../examples/pytorch/ogb/ogbn-proteins) - Example code: [PyTorch on ogbn-proteins](../examples/pytorch/ogb/ogbn-proteins)
- Tags: node classification, weighted graphs, OGB - Tags: node classification, weighted graphs, OGB
- <a name="sign"></a> Frasca et al. SIGN: Scalable Inception Graph Neural Networks. [Paper link](https://arxiv.org/abs/2004.11198). - <a name="sign"></a> Frasca et al. SIGN: Scalable Inception Graph Neural Networks. [Paper link](https://arxiv.org/abs/2004.11198).
- Example code: [PyTorch on ogbn-arxiv/products/mag](../examples/pytorch/ogb/sign), [PyTorch](../examples/pytorch/sign) - Example code: [PyTorch on ogbn-arxiv/products/mag](../examples/pytorch/ogb/sign), [PyTorch](../examples/pytorch/sign)
- Tags: node classification, OGB, large-scale, heterogeneous graphs - Tags: node classification, OGB, large-scale, heterogeneous graphs
- <a name="prestrategy"></a> Hu et al. Strategies for Pre-training Graph Neural Networks. [Paper link](https://arxiv.org/abs/1905.12265). - <a name="prestrategy"></a> Hu et al. Strategies for Pre-training Graph Neural Networks. [Paper link](https://arxiv.org/abs/1905.12265).
- Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration) - Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction - Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction
...
# Graph Random Neural Network (GRAND)
This DGL example implements the GNN model proposed in the paper [Graph Random Neural Network for Semi-Supervised Learning on Graphs](https://arxiv.org/abs/2005.11079).
Paper link: https://arxiv.org/abs/2005.11079
Author's code: https://github.com/THUDM/GRAND
Contributor: Hengrui Zhang ([@hengruizhang98](https://github.com/hengruizhang98))
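In brief (see the paper for the full formulation), the implementation below combines two ingredients. *Random propagation*: each of the $S$ DropNode-perturbed feature copies $\widetilde{X}^{(s)}$ is smoothed with the mixed-order operator

$$\overline{X}^{(s)} = \frac{1}{K+1}\sum_{k=0}^{K}\hat{A}^{k}\,\widetilde{X}^{(s)},\qquad \hat{A}=\widetilde{D}^{-1/2}(A+I)\widetilde{D}^{-1/2},$$

and then classified by a two-layer MLP (`GRANDConv` and `MLP` in `model.py`). *Consistency regularization*: the $S$ resulting predictions are pulled toward their temperature-sharpened average, and the training loss is the mean supervised loss over the $S$ copies plus $\lambda$ (`--lam`) times this consistency term (`consis_loss` in `main.py`).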
## Dependencies
- Python 3.7
- PyTorch 1.7.1
- numpy
- dgl 0.5.3
## Dataset
The DGL's built-in Cora, Pubmed and Citeseer datasets. Dataset summary:
| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |
## Arguments
###### Dataset options
```
--dataname str The graph dataset name. Default is 'cora'.
```
###### GPU options
```
--gpu int GPU index. Default is -1, using CPU.
```
###### Model options
```
--epochs int Number of training epochs. Default is 2000.
--early_stopping int Early stopping patience rounds. Default is 200.
--lr float Adam optimizer learning rate. Default is 0.01.
--weight_decay float L2 regularization coefficient. Default is 5e-4.
--dropnode_rate float Dropnode rate (1 - keep probability). Default is 0.5.
--input_droprate float Dropout rate of input layer. Default is 0.5.
--hidden_droprate float Dropout rate of hidden layer. Default is 0.5.
--hid_dim int Hidden layer dimensionalities. Default is 32.
--order                int     Number of propagation steps. Default is 8.
--sample               int     Number of DropNode augmentations. Default is 4.
--tem                  float   Sharpening temperature. Default is 0.5.
--lam                  float   Coefficient of consistency regularization. Default is 1.0.
--use_bn               bool    Use batch normalization. Default is False.
```
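For reference, a sketch (not part of the original README) of how these model options map onto the `GRAND` constructor defined in `model.py`; `n_features` and `n_classes` are taken from the loaded dataset, as in `main.py`:

```python
model = GRAND(n_features, args.hid_dim, n_classes,
              S=args.sample,                      # number of DropNode augmentations
              K=args.order,                       # propagation steps
              node_dropout=args.dropnode_rate,
              input_droprate=args.input_droprate,
              hidden_droprate=args.hidden_droprate,
              batchnorm=args.use_bn)
```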
## Examples
Train a model with the hyperparameters reported in the paper on each dataset:
```bash
# Cora:
python main.py --dataname cora --gpu 0 --lam 1.0 --tem 0.5 --order 8 --sample 4 --input_droprate 0.5 --hidden_droprate 0.5 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 100 --lr 1e-2 --epochs 2000
# Citeseer:
python main.py --dataname citeseer --gpu 0 --lam 0.7 --tem 0.3 --order 2 --sample 2 --input_droprate 0.0 --hidden_droprate 0.2 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 100 --lr 1e-2 --epochs 2000
# Pubmed:
python main.py --dataname pubmed --gpu 0 --lam 1.0 --tem 0.2 --order 5 --sample 4 --input_droprate 0.6 --hidden_droprate 0.8 --dropnode_rate 0.5 --hid_dim 32 --early_stopping 200 --lr 0.2 --epochs 2000 --use_bn
```
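Training saves the best checkpoint as `<dataname>.pkl` in the working directory. Below is a minimal sketch, not part of the original scripts, of reloading that checkpoint for standalone inference; the constructor arguments must match those used for training (here, the Cora settings above):

```python
import torch as th
import dgl
from dgl.data import CoraGraphDataset
from model import GRAND

dataset = CoraGraphDataset()
graph = dgl.add_self_loop(dataset[0])
feats = graph.ndata['feat']

# Rebuild the model with the Cora hyperparameters used above, then load the checkpoint.
model = GRAND(feats.shape[1], 32, dataset.num_classes, S=4, K=8,
              node_dropout=0.5, input_droprate=0.5, hidden_droprate=0.5)
model.load_state_dict(th.load('cora.pkl', map_location='cpu'))
model.eval()

with th.no_grad():
    log_probs = model(graph, feats, False)   # inference mode: one deterministic forward pass
pred = log_probs.argmax(dim=1)
```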
### Performance
The hyperparameter settings used in the commands above are identical to those reported in the paper.
| Dataset | Cora | Citeseer | Pubmed |
| :-: | :-: | :-: | :-: |
| Accuracy reported in the paper (100 runs) | **85.4 (±0.4)** | **75.4 (±0.4)** | 82.7 (±0.6) |
| Accuracy of this DGL example (20 runs) | 85.33 (±0.41) | 75.36 (±0.36) | **82.90 (±0.66)** |
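# main.py: training and evaluation script for the GRAND example (uses model.py below).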
import argparse
import numpy as np
import torch as th
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from model import GRAND
import warnings
warnings.filterwarnings('ignore')
def argument():
parser = argparse.ArgumentParser(description='GRAND')
# data source params
parser.add_argument('--dataname', type=str, default='cora', help='Name of dataset.')
# cuda params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using CPU.')
# training params
parser.add_argument('--epochs', type=int, default=200, help='Training epochs.')
    parser.add_argument('--early_stopping', type=int, default=200, help='Patience epochs to wait before early stopping.')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='L2 reg.')
# model params
parser.add_argument("--hid_dim", type=int, default=32, help='Hidden layer dimensionalities.')
parser.add_argument('--dropnode_rate', type=float, default=0.5,
help='Dropnode rate (1 - keep probability).')
parser.add_argument('--input_droprate', type=float, default=0.0,
help='dropout rate of input layer')
parser.add_argument('--hidden_droprate', type=float, default=0.0,
help='dropout rate of hidden layer')
parser.add_argument('--order', type=int, default=8, help='Propagation step')
parser.add_argument('--sample', type=int, default=4, help='Sampling times of dropnode')
parser.add_argument('--tem', type=float, default=0.5, help='Sharpening temperature')
parser.add_argument('--lam', type=float, default=1., help='Coefficient of consistency regularization')
parser.add_argument('--use_bn', action='store_true', default=False, help='Using Batch Normalization')
args = parser.parse_args()
# check cuda
if args.gpu != -1 and th.cuda.is_available():
args.device = 'cuda:{}'.format(args.gpu)
else:
args.device = 'cpu'
return args
def consis_loss(logps, temp, lam):
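    """Consistency regularization loss.

    Sharpen the average of the S predicted distributions with temperature
    `temp`, then penalize the squared distance between each prediction and
    the sharpened target (averaged over nodes and samples), scaled by `lam`.
    """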
ps = [th.exp(p) for p in logps]
ps = th.stack(ps, dim = 2)
avg_p = th.mean(ps, dim = 2)
sharp_p = (th.pow(avg_p, 1./temp) / th.sum(th.pow(avg_p, 1./temp), dim=1, keepdim=True)).detach()
sharp_p = sharp_p.unsqueeze(2)
    loss = th.mean(th.sum(th.pow(ps - sharp_p, 2), dim=1, keepdim=True))
loss = lam * loss
return loss
if __name__ == '__main__':
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load from DGL dataset
args = argument()
print(args)
if args.dataname == 'cora':
dataset = CoraGraphDataset()
elif args.dataname == 'citeseer':
dataset = CiteseerGraphDataset()
    elif args.dataname == 'pubmed':
        dataset = PubmedGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataname))
graph = dataset[0]
graph = dgl.add_self_loop(graph)
device = args.device
# retrieve the number of classes
n_classes = dataset.num_classes
# retrieve labels of ground truth
labels = graph.ndata.pop('label').to(device).long()
# Extract node features
feats = graph.ndata.pop('feat').to(device)
n_features = feats.shape[-1]
# retrieve masks for train/validation/test
train_mask = graph.ndata.pop('train_mask')
val_mask = graph.ndata.pop('val_mask')
test_mask = graph.ndata.pop('test_mask')
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze().to(device)
val_idx = th.nonzero(val_mask, as_tuple=False).squeeze().to(device)
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze().to(device)
# Step 2: Create model =================================================================== #
model = GRAND(n_features, args.hid_dim, n_classes, args.sample, args.order,
args.dropnode_rate, args.input_droprate,
args.hidden_droprate, args.use_bn)
model = model.to(args.device)
graph = graph.to(args.device)
# Step 3: Create training components ===================================================== #
loss_fn = nn.NLLLoss()
opt = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.weight_decay)
loss_best = np.inf
acc_best = 0
# Step 4: training epoches =============================================================== #
for epoch in range(args.epochs):
''' Training '''
model.train()
loss_sup = 0
logits = model(graph, feats, True)
# calculate supervised loss
for k in range(args.sample):
loss_sup += F.nll_loss(logits[k][train_idx], labels[train_idx])
loss_sup = loss_sup/args.sample
# calculate consistency loss
loss_consis = consis_loss(logits, args.tem, args.lam)
loss_train = loss_sup + loss_consis
acc_train = th.sum(logits[0][train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
# backward
opt.zero_grad()
loss_train.backward()
opt.step()
''' Validating '''
model.eval()
with th.no_grad():
val_logits = model(graph, feats, False)
loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])
acc_val = th.sum(val_logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
# Print out performance
print("In epoch {}, Train Acc: {:.4f} | Train Loss: {:.4f} ,Val Acc: {:.4f} | Val Loss: {:.4f}".
format(epoch, acc_train, loss_train.item(), acc_val, loss_val.item()))
# set early stopping counter
if loss_val < loss_best or acc_val > acc_best:
if loss_val < loss_best:
best_epoch = epoch
th.save(model.state_dict(), args.dataname +'.pkl')
no_improvement = 0
loss_best = min(loss_val, loss_best)
acc_best = max(acc_val, acc_best)
else:
no_improvement += 1
if no_improvement == args.early_stopping:
print('Early stopping.')
break
print("Optimization Finished!")
    print('Loading model from epoch {}'.format(best_epoch))
model.load_state_dict(th.load(args.dataname +'.pkl'))
''' Testing '''
model.eval()
    with th.no_grad():
        test_logits = model(graph, feats, False)
test_acc = th.sum(test_logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc: {:.4f}".format(test_acc))
import numpy as np
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
def drop_node(feats, drop_rate, training):
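    """DropNode: during training, zero out each node's entire feature vector with
    probability `drop_rate`; at inference, scale the features by (1 - drop_rate)
    so that the expected output matches training."""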
n = feats.shape[0]
drop_rates = th.FloatTensor(np.ones(n) * drop_rate)
if training:
masks = th.bernoulli(1. - drop_rates).unsqueeze(1)
feats = masks.to(feats.device) * feats
else:
feats = feats * (1. - drop_rate)
return feats
class MLP(nn.Module):
def __init__(self, nfeat, nhid, nclass, input_droprate, hidden_droprate, use_bn =False):
super(MLP, self).__init__()
self.layer1 = nn.Linear(nfeat, nhid, bias = True)
self.layer2 = nn.Linear(nhid, nclass, bias = True)
self.input_dropout = nn.Dropout(input_droprate)
self.hidden_dropout = nn.Dropout(hidden_droprate)
self.bn1 = nn.BatchNorm1d(nfeat)
self.bn2 = nn.BatchNorm1d(nhid)
self.use_bn = use_bn
def reset_parameters(self):
self.layer1.reset_parameters()
self.layer2.reset_parameters()
def forward(self, x):
if self.use_bn:
x = self.bn1(x)
x = self.input_dropout(x)
x = F.relu(self.layer1(x))
if self.use_bn:
x = self.bn2(x)
x = self.hidden_dropout(x)
x = self.layer2(x)
return x
def GRANDConv(graph, feats, order):
'''
Parameters
-----------
graph: dgl.Graph
The input graph
feats: Tensor (n_nodes * feat_dim)
Node features
order: int
Propagation Steps
'''
with graph.local_scope():
        r''' Calculate symmetric normalized adjacency matrix \hat{A} '''
degs = graph.in_degrees().float().clamp(min=1)
norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)
graph.ndata['norm'] = norm
graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight'))
''' Graph Conv '''
x = feats
        y = feats.clone()   # accumulator for the (K+1)-term average, starting with the k=0 term
for i in range(order):
graph.ndata['h'] = x
graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h'))
x = graph.ndata.pop('h')
y.add_(x)
return y /(order + 1)
class GRAND(nn.Module):
r"""
Parameters
-----------
in_dim: int
        Input feature size; i.e., the number of dimensions of :math:`H^{(i)}`.
hid_dim: int
Hidden feature size.
n_class: int
Number of classes.
    S: int
        Number of augmentation samples (DropNode copies).
    K: int
        Number of propagation steps.
    node_dropout: float
        DropNode rate on node features.
    input_droprate: float
        Dropout rate of the input layer of the MLP.
    hidden_droprate: float
        Dropout rate of the hidden layer of the MLP.
batchnorm: bool, optional
If True, use batch normalization.
"""
def __init__(self,
in_dim,
hid_dim,
n_class,
S = 1,
K = 3,
node_dropout=0.0,
input_droprate = 0.0,
hidden_droprate = 0.0,
batchnorm=False):
super(GRAND, self).__init__()
self.in_dim = in_dim
self.hid_dim = hid_dim
self.S = S
self.K = K
self.n_class = n_class
self.mlp = MLP(in_dim, hid_dim, n_class, input_droprate, hidden_droprate, batchnorm)
self.dropout = node_dropout
self.node_dropout = nn.Dropout(node_dropout)
def forward(self, graph, feats, training = True):
X = feats
S = self.S
if training: # Training Mode
output_list = []
for s in range(S):
drop_feat = drop_node(X, self.dropout, True) # Drop node
feat = GRANDConv(graph, drop_feat, self.K) # Graph Convolution
output_list.append(th.log_softmax(self.mlp(feat), dim=-1)) # Prediction
return output_list
else: # Inference Mode
drop_feat = drop_node(X, self.dropout, False)
X = GRANDConv(graph, drop_feat, self.K)
return th.log_softmax(self.mlp(X), dim = -1)