Unverified commit 2afa3598 authored by xnouhz, committed by GitHub

[Example] JKNet (#2795)



* [example] jknet

* update

* update

* update
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 29801377
@@ -88,6 +88,7 @@ The folder contains example implementations of selected research papers related
| [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
| [Composition-based Multi-Relational Graph Convolutional Networks](#compgcn)| | :heavy_check_mark: | | | |
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
| [Representation Learning on Graphs with Jumping Knowledge Networks](#jknet) | :heavy_check_mark: | | | | |
## 2021
@@ -254,6 +255,10 @@ The folder contains example implementations of selected research papers related
- Example code: [pytorch](../examples/pytorch/seal)
- Tags: link prediction, sampling
- <a name="jknet"></a> Xu et al. Representation Learning on Graphs with Jumping Knowledge Networks. [Paper link](https://arxiv.org/abs/1806.03536).
- Example code: [pytorch](../examples/pytorch/jknet)
- Tags: message passing, neighborhood
## 2017
......
# DGL Implementation of JKNet
This DGL example implements the GNN model proposed in the paper [Representation Learning on Graphs with Jumping Knowledge Networks](https://arxiv.org/abs/1806.03536).
Contributor: [xnuohz](https://github.com/xnuohz)
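JKNet stacks several GCN layers and then lets every node "jump" to the last layer, where the per-layer representations are combined with one of three readouts: concatenation (`cat`), element-wise max (`max`), or an LSTM-attention weighted sum (`lstm`). The snippet below is a minimal sketch of those three readouts on dummy per-layer features; the tensor sizes and variable names are illustrative only and simply mirror the logic in `model.py`.

```python
import torch
import torch.nn as nn

num_nodes, hid_dim, num_layers = 4, 8, 3
# dummy per-layer node representations (one tensor per GCN layer output)
feat_lst = [torch.randn(num_nodes, hid_dim) for _ in range(num_layers + 1)]

# 'cat': concatenate the layer outputs along the feature dimension
out_cat = torch.cat(feat_lst, dim=-1)                   # (num_nodes, hid_dim * (num_layers + 1))

# 'max': element-wise max over the layers
out_max = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]  # (num_nodes, hid_dim)

# 'lstm': per-layer attention scores from a bidirectional LSTM
lstm = nn.LSTM(hid_dim, hid_dim, bidirectional=True, batch_first=True)
attn = nn.Linear(2 * hid_dim, 1)
x = torch.stack(feat_lst, dim=1)                        # (num_nodes, num_layers + 1, hid_dim)
alpha = torch.softmax(attn(lstm(x)[0]).squeeze(-1), dim=-1).unsqueeze(-1)
out_lstm = (x * alpha).sum(dim=1)                       # (num_nodes, hid_dim)
```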
### Requirements
The codebase is implemented in Python 3.6. For the version requirements of the packages, see below.
```
dgl 0.6.0
scikit-learn 0.24.1
tqdm 4.56.0
torch 1.7.1
```
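The pinned versions above are available from PyPI, so the environment can be set up with pip as a sketch (assuming a CPU-only build of DGL; CUDA builds are published as separate `dgl-cuXX` packages):
```bash
pip install dgl==0.6.0 scikit-learn==0.24.1 tqdm==4.56.0 torch==1.7.1
```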
### The graph datasets used in this example
###### Node Classification
DGL's built-in Cora and Citeseer datasets. Dataset summary:
| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Cora | 2,708 | 10,556 | 1,433 | 7 (single label) | 60% | 20% | 20% |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 (single label) | 60% | 20% | 20% |
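Both datasets are downloaded automatically by DGL on first use. As a quick sanity check of the statistics above, the following sketch (independent of the training script) prints them from the built-in dataset objects:
```python
from dgl.data import CoraGraphDataset, CiteseerGraphDataset

for dataset in (CoraGraphDataset(), CiteseerGraphDataset()):
    g = dataset[0]
    print(dataset.name, g.num_nodes(), g.num_edges(),
          g.ndata['feat'].shape[1], dataset.num_classes)
```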
### Usage
###### Dataset options
```
--dataset str The graph dataset name. Default is 'Cora'.
```
###### GPU options
```
--gpu int GPU index. Default is -1, using CPU.
```
###### Model options
```
--run int Number of runs. Default is 10.
--epochs int Number of training epochs. Default is 500.
--lr float Adam optimizer learning rate. Default is 0.005.
--lamb float L2 regularization coefficient. Default is 0.0005.
--hid-dim int Hidden layer dimensionalities. Default is 32.
--num-layers int Number of GCN layers. Default is 5.
--mode str Type of aggregation ['cat', 'max', 'lstm']. Default is 'cat'.
--dropout float Dropout applied at all layers. Default is 0.5.
```
###### Examples
The following commands train a JKNet with the hyperparameters from the paper on different datasets and evaluate it on the test set.
```bash
# Cora:
python main.py --gpu 0 --mode max --num-layers 6
python main.py --gpu 0 --mode cat --num-layers 6
python main.py --gpu 0 --mode lstm --num-layers 1
# Citeseer:
python main.py --gpu 0 --dataset Citeseer --mode max --num-layers 1
python main.py --gpu 0 --dataset Citeseer --mode cat --num-layers 1
python main.py --gpu 0 --dataset Citeseer --mode lstm --num-layers 2
```
### Performance
**As the authors did not release their code, we do not have access to the data splits they used; the results below use random 60/20/20 splits.**
###### Node Classification
* Cora
| | JK-Maxpool | JK-Concat | JK-LSTM |
| :-: | :-: | :-: | :-: |
| Accuracy (paper, Table 2) | 89.6±0.5 | 89.1±1.1 | 85.8±1.0 |
| Accuracy (DGL) | 86.1±1.5 | 85.1±1.6 | 84.2±1.6 |
* Citeseer
| | JK-Maxpool | JK-Concat | JK-LSTM |
| :-: | :-: | :-: | :-: |
| Accuracy (paper, Table 2) | 77.7±0.5 | 78.3±0.8 | 74.7±0.9 |
| Accuracy (DGL) | 70.9±1.9 | 73.0±1.5 | 69.0±1.7 |
""" The main file to train a JKNet model using a full graph """
import argparse
import copy
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from dgl.data import CoraGraphDataset, CiteseerGraphDataset
from tqdm import trange
from sklearn.model_selection import train_test_split
from model import JKNet
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load from DGL dataset
if args.dataset == 'Cora':
dataset = CoraGraphDataset()
elif args.dataset == 'Citeseer':
dataset = CiteseerGraphDataset()
else:
raise ValueError('Dataset {} is invalid.'.format(args.dataset))
graph = dataset[0]
# check cuda
device = f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu'
# retrieve the number of classes
n_classes = dataset.num_classes
# retrieve labels of ground truth
labels = graph.ndata.pop('label').to(device).long()
# Extract node features
feats = graph.ndata.pop('feat').to(device)
n_features = feats.shape[-1]
# create masks for train / validation / test
# train : val : test = 6 : 2 : 2
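    # note: these are random splits; the splits used in the paper are not public (see README)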
n_nodes = graph.num_nodes()
idx = torch.arange(n_nodes).to(device)
train_idx, test_idx = train_test_split(idx, test_size=0.2)
train_idx, val_idx = train_test_split(train_idx, test_size=0.25)
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = JKNet(in_dim=n_features,
hid_dim=args.hid_dim,
out_dim=n_classes,
num_layers=args.num_layers,
mode=args.mode,
dropout=args.dropout).to(device)
best_model = copy.deepcopy(model)
# Step 3: Create training components ===================================================== #
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)
# Step 4: training epochs =============================================================== #
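    # best validation accuracy seen so far, used to select the model evaluated on the test set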
acc = 0
epochs = trange(args.epochs, desc='Accuracy & Loss')
for _ in epochs:
# Training using a full graph
model.train()
logits = model(graph, feats)
# compute loss
train_loss = loss_fn(logits[train_idx], labels[train_idx])
train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]).item() / len(train_idx)
# backward
opt.zero_grad()
train_loss.backward()
opt.step()
        # Validation using a full graph
        model.eval()

        with torch.no_grad():
            # recompute the logits in eval mode so dropout is disabled for validation
            logits = model(graph, feats)
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]).item() / len(val_idx)
# Print out performance
epochs.set_description('Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}'.format(
train_acc, train_loss.item(), valid_acc, valid_loss.item()))
if valid_acc > acc:
acc = valid_acc
best_model = copy.deepcopy(model)
    # Test with the best model selected on the validation set
    best_model.eval()
    with torch.no_grad():
        logits = best_model(graph, feats)
        test_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx]).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
return test_acc
if __name__ == "__main__":
"""
JKNet Hyperparameters
"""
parser = argparse.ArgumentParser(description='JKNet')
# data source params
parser.add_argument('--dataset', type=str, default='Cora', help='Name of dataset.')
# cuda params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index. Default: -1, using CPU.')
# training params
parser.add_argument('--run', type=int, default=10, help='Running times.')
parser.add_argument('--epochs', type=int, default=500, help='Training epochs.')
parser.add_argument('--lr', type=float, default=0.005, help='Learning rate.')
parser.add_argument('--lamb', type=float, default=0.0005, help='L2 reg.')
# model params
parser.add_argument("--hid-dim", type=int, default=32, help='Hidden layer dimensionalities.')
parser.add_argument("--num-layers", type=int, default=5, help='Number of GCN layers.')
parser.add_argument("--mode", type=str, default='cat', help="Type of aggregation.", choices=['cat', 'max', 'lstm'])
parser.add_argument("--dropout", type=float, default=0.5, help='Dropout applied at all layers.')
args = parser.parse_args()
print(args)
acc_lists = []
for _ in range(args.run):
acc_lists.append(main(args))
mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
std = np.around(np.std(acc_lists, axis=0), decimals=3)
print('total acc: ', acc_lists)
print('mean', mean)
print('std', std)
examples/pytorch/jknet/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.conv import GraphConv
class JKNet(nn.Module):
def __init__(self,
in_dim,
hid_dim,
out_dim,
num_layers=1,
mode='cat',
dropout=0.):
super(JKNet, self).__init__()
self.mode = mode
self.dropout = nn.Dropout(dropout)
self.layers = nn.ModuleList()
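        # one input GraphConv (in_dim -> hid_dim) followed by num_layers hidden GraphConvs (hid_dim -> hid_dim)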
self.layers.append(GraphConv(in_dim, hid_dim, activation=F.relu))
for _ in range(num_layers):
self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))
if self.mode == 'cat':
hid_dim = hid_dim * (num_layers + 1)
elif self.mode == 'lstm':
self.lstm = nn.LSTM(hid_dim, (num_layers * hid_dim) // 2, bidirectional=True, batch_first=True)
self.attn = nn.Linear(2 * ((num_layers * hid_dim) // 2), 1)
self.output = nn.Linear(hid_dim, out_dim)
self.reset_params()
def reset_params(self):
self.output.reset_parameters()
for layers in self.layers:
layers.reset_parameters()
if self.mode == 'lstm':
self.lstm.reset_parameters()
self.attn.reset_parameters()
def forward(self, g, feats):
feat_lst = []
for layer in self.layers:
feats = self.dropout(layer(g, feats))
feat_lst.append(feats)
if self.mode == 'cat':
out = torch.cat(feat_lst, dim=-1)
elif self.mode == 'max':
out = torch.stack(feat_lst, dim=-1).max(dim=-1)[0]
        else:
            # JK-LSTM: run a bidirectional LSTM over the stacked per-layer representations,
            # score each layer with a linear attention head, and take a weighted sum
            x = torch.stack(feat_lst, dim=1)            # (N, num_layers + 1, hid_dim)
            alpha, _ = self.lstm(x)
            alpha = self.attn(alpha).squeeze(-1)        # (N, num_layers + 1)
            alpha = torch.softmax(alpha, dim=-1).unsqueeze(-1)
            out = (x * alpha).sum(dim=1)
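        # a final neighborhood sum-aggregation on the combined representation before the linear classifier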
g.ndata['h'] = out
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
return self.output(g.ndata['h'])