Unverified Commit 9a16a5e0 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Refactor] Refactor GIN example (#4280)



* Refactor GIN example

* Update

* Update README

* Minor update

* README update
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent d6957c28
...@@ -6,78 +6,28 @@ Graph Isomorphism Network (GIN) ...@@ -6,78 +6,28 @@ Graph Isomorphism Network (GIN)
Dependencies Dependencies
------------ ------------
- PyTorch 1.1.0+
- sklearn - sklearn
- tqdm
``bash Install as follows:
pip install torch sklearn tqdm
``
How to run
----------
An experiment on the GIN in default settings can be run with
```bash
python main.py
```
An experiment on the GIN in customized settings can be run with
```bash ```bash
python main.py [--device 0 | --disable-cuda] --dataset COLLAB \ pip install sklearn
--graph_pooling_type max --neighbor_pooling_type sum
``` ```
add `--degree_as_nlabel` to use one-hot encodings of node degrees as node feature vectors
Results How to run
------- -------
results may **fluctuate**, due to random factors and the relatively small data set. if you want to follow the paper's setting, consider the script below. Run with the following for bioinformatics graph classification (available datasets: MUTAG (default), PTC, NCI1, and PROTEINS)
```bash ```bash
# 4 bioinformatics datasets setting graph_pooling_type=sum, the nodes have categorical input features python3 train.py --dataset MUTAG
python main.py --dataset MUTAG --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum --filename MUTAG.txt
python main.py --dataset PTC --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum --filename PTC.txt
python main.py --dataset NCI1 --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum --filename NCI1.txt
python main.py --dataset PROTEINS --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum --filename PROTEINS.txt
# 5 social network datasets setting graph_pooling_type=mean, for the REDDIT datasets, we set all node feature vectors to be the same
# (thus, features here are uninformative); for the other social networks, we use one-hot encodings of node degrees.
python main.py --dataset COLLAB --device 0 \
--graph_pooling_type mean --neighbor_pooling_type sum --degree_as_nlabel --filename COLLAB.txt
python main.py --dataset IMDBBINARY --device 0 \
--graph_pooling_type mean --neighbor_pooling_type sum --degree_as_nlabel --filename IMDBBINARY.txt
python main.py --dataset IMDBMULTI --device 0 \
--graph_pooling_type mean --neighbor_pooling_type sum --degree_as_nlabel --filename IMDBMULTI.txt
python main.py --dataset REDDITBINARY --device 0 \
--graph_pooling_type mean --neighbor_pooling_type sum --filename REDDITBINARY.txt --fold_idx 6 --epoch 120
python main.py --dataset REDDITMULTI5K --device 0 \
--graph_pooling_type mean --neighbor_pooling_type sum --filename REDDITMULTI5K.txt
``` ```
one fold of 10 result are below. > **_NOTE:_** Users may observe results fluctuate due to the randomness with relatively small dataset. In consistence with the original [paper](https://arxiv.org/abs/1810.00826), five social network datasets, 'COLLAB', 'IMDBBINARY' 'IMDBMULTI' 'REDDITBINARY' and 'REDDITMULTI5K', are also available as the input. Users are encouraged to update the script slightly for social network applications, for example, replacing sum readout on bioinformatics datasets with mean readout on social network datasets and using one-hot encodings of node degrees by setting "degree_as_nlabel=True" in GINDataset.
| dataset | our result | paper report |
| ------------- | ---------- | ------------ |
| MUTAG | 89.4 | 89.4 ± 5.6 |
| PTC | 68.5 | 64.6 ± 7.0 |
| NCI1 | 78.5 | 82.7 ± 1.7 |
| PROTEINS | 72.3 | 76.2 ± 2.8 |
| COLLAB | 81.6 | 80.2 ± 1.9 |
| IMDBBINARY | 73.0 | 75.1 ± 5.1 |
| IMDBMULTI | 54.0 | 52.3 ± 2.8 |
| REDDITBINARY | 88.0 | 92.4 ± 2.5 |
| REDDITMULTI5K | 54.8 | 57.5 ± 1.5 |
Summary (10-fold cross-validation)
-------
| Dataset | Result
| ------------- | -------
| MUTAG | ~89.4
| PTC | ~68.5
| NCI1 | ~82.9
| PROTEINS | ~74.1
"""
PyTorch compatible dataloader
"""
import math
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
import dgl
from dgl.dataloading import GraphDataLoader
class GINDataLoader():
def __init__(self,
dataset,
batch_size,
device,
collate_fn=None,
seed=0,
shuffle=True,
split_name='fold10',
fold_idx=0,
split_ratio=0.7):
self.shuffle = shuffle
self.seed = seed
self.kwargs = {'pin_memory': True} if 'cuda' in device.type else {}
labels = [l for _, l in dataset]
if split_name == 'fold10':
train_idx, valid_idx = self._split_fold10(
labels, fold_idx, seed, shuffle)
elif split_name == 'rand':
train_idx, valid_idx = self._split_rand(
labels, split_ratio, seed, shuffle)
else:
raise NotImplementedError()
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
self.train_loader = GraphDataLoader(
dataset, sampler=train_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
self.valid_loader = GraphDataLoader(
dataset, sampler=valid_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
def train_valid_loader(self):
return self.train_loader, self.valid_loader
def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
''' 10 flod '''
assert 0 <= fold_idx and fold_idx < 10, print(
"fold_idx must be from 0 to 9.")
skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
idx_list = []
for idx in skf.split(np.zeros(len(labels)), labels): # split(x, y)
idx_list.append(idx)
train_idx, valid_idx = idx_list[fold_idx]
print(
"train_set : test_set = %d : %d",
len(train_idx), len(valid_idx))
return train_idx, valid_idx
def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
num_entries = len(labels)
indices = list(range(num_entries))
np.random.seed(seed)
np.random.shuffle(indices)
split = int(math.floor(split_ratio * num_entries))
train_idx, valid_idx = indices[:split], indices[split:]
print(
"train_set : test_set = %d : %d",
len(train_idx), len(valid_idx))
return train_idx, valid_idx
"""
How Powerful are Graph Neural Networks
https://arxiv.org/abs/1810.00826
https://openreview.net/forum?id=ryGs6iA5Km
Author's implementation: https://github.com/weihua916/powerful-gnns
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling
class ApplyNodeFunc(nn.Module):
"""Update the node feature hv with MLP, BN and ReLU."""
def __init__(self, mlp):
super(ApplyNodeFunc, self).__init__()
self.mlp = mlp
self.bn = nn.BatchNorm1d(self.mlp.output_dim)
def forward(self, h):
h = self.mlp(h)
h = self.bn(h)
h = F.relu(h)
return h
class MLP(nn.Module):
"""MLP with linear output"""
def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
"""MLP layers construction
Paramters
---------
num_layers: int
The number of linear layers
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The number of classes for prediction
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.output_dim = output_dim
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
for layer in range(num_layers - 2):
self.linears.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
for layer in range(num_layers - 1):
self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
def forward(self, x):
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for i in range(self.num_layers - 1):
h = F.relu(self.batch_norms[i](self.linears[i](h)))
return self.linears[-1](h)
class GIN(nn.Module):
"""GIN model"""
def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
output_dim, final_dropout, learn_eps, graph_pooling_type,
neighbor_pooling_type):
"""model parameters setting
Paramters
---------
num_layers: int
The number of linear layers in the neural network
num_mlp_layers: int
The number of linear layers in mlps
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The number of classes for prediction
final_dropout: float
dropout ratio on the final linear layer
learn_eps: boolean
If True, learn epsilon to distinguish center nodes from neighbors
If False, aggregate neighbors and center nodes altogether.
neighbor_pooling_type: str
how to aggregate neighbors (sum, mean, or max)
graph_pooling_type: str
how to aggregate entire nodes in a graph (sum, mean or max)
"""
super(GIN, self).__init__()
self.num_layers = num_layers
self.learn_eps = learn_eps
# List of MLPs
self.ginlayers = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(self.num_layers - 1):
if layer == 0:
mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
else:
mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)
self.ginlayers.append(
GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# Linear function for graph poolings of output of each layer
# which maps the output of different layers into a prediction score
self.linears_prediction = torch.nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.linears_prediction.append(
nn.Linear(input_dim, output_dim))
else:
self.linears_prediction.append(
nn.Linear(hidden_dim, output_dim))
self.drop = nn.Dropout(final_dropout)
if graph_pooling_type == 'sum':
self.pool = SumPooling()
elif graph_pooling_type == 'mean':
self.pool = AvgPooling()
elif graph_pooling_type == 'max':
self.pool = MaxPooling()
else:
raise NotImplementedError
def forward(self, g, h):
# list of hidden representation at each layer (including input)
hidden_rep = [h]
for i in range(self.num_layers - 1):
h = self.ginlayers[i](g, h)
h = self.batch_norms[i](h)
h = F.relu(h)
hidden_rep.append(h)
score_over_layer = 0
# perform pooling over all nodes in each graph in every layer
for i, h in enumerate(hidden_rep):
pooled_h = self.pool(g, h)
score_over_layer += self.drop(self.linears_prediction[i](pooled_h))
return score_over_layer
"""Parser for arguments
Put all arguments in one file and group similar arguments
"""
import argparse
class Parser():
def __init__(self, description):
'''
arguments parser
'''
self.parser = argparse.ArgumentParser(description=description)
self.args = None
self._parse()
def _parse(self):
# dataset
self.parser.add_argument(
'--dataset', type=str, default="MUTAG",
choices=['MUTAG', 'COLLAB', 'IMDBBINARY', 'IMDBMULTI', 'NCI1', 'PROTEINS', 'PTC', 'REDDITBINARY', 'REDDITMULTI5K'],
help='name of dataset (default: MUTAG)')
self.parser.add_argument(
'--batch_size', type=int, default=32,
help='batch size for training and validation (default: 32)')
self.parser.add_argument(
'--fold_idx', type=int, default=0,
help='the index(<10) of fold in 10-fold validation.')
self.parser.add_argument(
'--filename', type=str, default="",
help='output file')
self.parser.add_argument(
'--degree_as_nlabel', action="store_true",
help='use one-hot encodings of node degrees as node feature vectors')
# device
self.parser.add_argument(
'--disable-cuda', action='store_true',
help='Disable CUDA')
self.parser.add_argument(
'--device', type=int, default=0,
help='which gpu device to use (default: 0)')
# net
self.parser.add_argument(
'--num_layers', type=int, default=5,
help='number of layers (default: 5)')
self.parser.add_argument(
'--num_mlp_layers', type=int, default=2,
help='number of MLP layers(default: 2). 1 means linear model.')
self.parser.add_argument(
'--hidden_dim', type=int, default=64,
help='number of hidden units (default: 64)')
# graph
self.parser.add_argument(
'--graph_pooling_type', type=str,
default="sum", choices=["sum", "mean", "max"],
help='type of graph pooling: sum, mean or max')
self.parser.add_argument(
'--neighbor_pooling_type', type=str,
default="sum", choices=["sum", "mean", "max"],
help='type of neighboring pooling: sum, mean or max')
self.parser.add_argument(
'--learn_eps', action="store_true",
help='learn the epsilon weighting')
# learning
self.parser.add_argument(
'--seed', type=int, default=0,
help='random seed (default: 0)')
self.parser.add_argument(
'--epochs', type=int, default=350,
help='number of epochs to train (default: 350)')
self.parser.add_argument(
'--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
self.parser.add_argument(
'--final_dropout', type=float, default=0.5,
help='final layer dropout (default: 0.5)')
# done
self.args = self.parser.parse_args()
import sys
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import GINDataset
from dataloader import GINDataLoader
from ginparser import Parser
from gin import GIN
def train(args, net, trainloader, optimizer, criterion, epoch):
net.train()
running_loss = 0
total_iters = len(trainloader)
# setup the offset to avoid the overlap with mouse cursor
bar = tqdm(range(total_iters), unit='batch', position=2, file=sys.stdout)
for pos, (graphs, labels) in zip(bar, trainloader):
# batch graphs will be shipped to device in forward part of model
labels = labels.to(args.device)
graphs = graphs.to(args.device)
feat = graphs.ndata.pop('attr')
outputs = net(graphs, feat)
loss = criterion(outputs, labels)
running_loss += loss.item()
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
# report
bar.set_description('epoch-{}'.format(epoch))
bar.close()
# the final batch will be aligned
running_loss = running_loss / total_iters
return running_loss
def eval_net(args, net, dataloader, criterion):
net.eval()
total = 0
total_loss = 0
total_correct = 0
for data in dataloader:
graphs, labels = data
graphs = graphs.to(args.device)
labels = labels.to(args.device)
feat = graphs.ndata.pop('attr')
total += len(labels)
outputs = net(graphs, feat)
_, predicted = torch.max(outputs.data, 1)
total_correct += (predicted == labels.data).sum().item()
loss = criterion(outputs, labels)
# crossentropy(reduce=True) for default
total_loss += loss.item() * len(labels)
loss, acc = 1.0*total_loss / total, 1.0*total_correct / total
net.train()
return loss, acc
def main(args):
# set up seeds, args.seed supported
torch.manual_seed(seed=args.seed)
np.random.seed(seed=args.seed)
is_cuda = not args.disable_cuda and torch.cuda.is_available()
if is_cuda:
args.device = torch.device("cuda:" + str(args.device))
torch.cuda.manual_seed_all(seed=args.seed)
else:
args.device = torch.device("cpu")
dataset = GINDataset(args.dataset, not args.learn_eps, args.degree_as_nlabel)
trainloader, validloader = GINDataLoader(
dataset, batch_size=args.batch_size, device=args.device,
seed=args.seed, shuffle=True,
split_name='fold10', fold_idx=args.fold_idx).train_valid_loader()
# or split_name='rand', split_ratio=0.7
model = GIN(
args.num_layers, args.num_mlp_layers,
dataset.dim_nfeats, args.hidden_dim, dataset.gclasses,
args.final_dropout, args.learn_eps,
args.graph_pooling_type, args.neighbor_pooling_type).to(args.device)
criterion = nn.CrossEntropyLoss() # defaul reduce is true
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# it's not cost-effective to hanle the cursor and init 0
# https://stackoverflow.com/a/23121189
tbar = tqdm(range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout)
vbar = tqdm(range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout)
lrbar = tqdm(range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)
for epoch, _, _ in zip(tbar, vbar, lrbar):
train(args, model, trainloader, optimizer, criterion, epoch)
scheduler.step()
train_loss, train_acc = eval_net(
args, model, trainloader, criterion)
tbar.set_description(
'train set - average loss: {:.4f}, accuracy: {:.0f}%'
.format(train_loss, 100. * train_acc))
valid_loss, valid_acc = eval_net(
args, model, validloader, criterion)
vbar.set_description(
'valid set - average loss: {:.4f}, accuracy: {:.0f}%'
.format(valid_loss, 100. * valid_acc))
if not args.filename == "":
with open(args.filename, 'a') as f:
f.write('%s %s %s %s %s' % (
args.dataset,
args.learn_eps,
args.neighbor_pooling_type,
args.graph_pooling_type,
epoch
))
f.write("\n")
f.write("%f %f %f %f" % (
train_loss,
train_acc,
valid_loss,
valid_acc
))
f.write("\n")
lrbar.set_description(
"Learning eps with learn_eps={}: {}".format(
args.learn_eps, [layer.eps.data.item() for layer in model.ginlayers]))
tbar.close()
vbar.close()
lrbar.close()
if __name__ == '__main__':
args = Parser(description='GIN').args
print('show all arguments configuration...')
print(args)
main(args)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling
import argparse
class MLP(nn.Module):
"""Construct two-layer MLP-type aggreator for GIN model"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.linears = nn.ModuleList()
# two-layer MLP
self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
self.batch_norm = nn.BatchNorm1d((hidden_dim))
def forward(self, x):
h = x
h = F.relu(self.batch_norm(self.linears[0](h)))
return self.linears[1](h)
class GIN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.ginlayers = nn.ModuleList()
self.batch_norms = nn.ModuleList()
num_layers = 5
# five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme
for layer in range(num_layers - 1): # excluding the input layer
if layer == 0:
mlp = MLP(input_dim, hidden_dim, hidden_dim)
else:
mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
self.ginlayers.append(GINConv(mlp, learn_eps=False)) # set to True if learning epsilon
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# linear functions for graph sum poolings of output of each layer
self.linear_prediction = nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.linear_prediction.append(nn.Linear(input_dim, output_dim))
else:
self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
self.drop = nn.Dropout(0.5)
self.pool = SumPooling() # change to mean readout (AvgPooling) on social network datasets
def forward(self, g, h):
# list of hidden representation at each layer (including the input layer)
hidden_rep = [h]
for i, layer in enumerate(self.ginlayers):
h = layer(g, h)
h = self.batch_norms[i](h)
h = F.relu(h)
hidden_rep.append(h)
score_over_layer = 0
# perform graph sum pooling over all nodes in each layer
for i, h in enumerate(hidden_rep):
pooled_h = self.pool(g, h)
score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
return score_over_layer
def split_fold10(labels, fold_idx=0):
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
idx_list = []
for idx in skf.split(np.zeros(len(labels)), labels):
idx_list.append(idx)
train_idx, valid_idx = idx_list[fold_idx]
return train_idx, valid_idx
def evaluate(dataloader, device, model):
model.eval()
total = 0
total_correct = 0
for batched_graph, labels in dataloader:
batched_graph = batched_graph.to(device)
labels = labels.to(device)
feat = batched_graph.ndata.pop('attr')
total += len(labels)
logits = model(batched_graph, feat)
_, predicted = torch.max(logits, 1)
total_correct += (predicted == labels).sum().item()
acc = 1.0 * total_correct / total
return acc
def train(train_loader, val_loader, device, model):
# loss function, optimizer and scheduler
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# training loop
for epoch in range(350):
model.train()
total_loss = 0
for batch, (batched_graph, labels) in enumerate(train_loader):
batched_graph = batched_graph.to(device)
labels = labels.to(device)
feat = batched_graph.ndata.pop('attr')
logits = model(batched_graph, feat)
loss = loss_fcn(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
train_acc = evaluate(train_loader, device, model)
valid_acc = evaluate(val_loader, device, model)
print("Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f} "
. format(epoch, total_loss / (batch + 1), train_acc, valid_acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default="MUTAG",
choices=['MUTAG', 'PTC', 'NCI1', 'PROTEINS'],
help='name of dataset (default: MUTAG)')
args = parser.parse_args()
print(f'Training with DGL built-in GINConv module with a fixed epsilon = 0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load and split dataset
dataset = GINDataset(args.dataset, self_loop=True, degree_as_nlabel=False) # add self_loop and disable one-hot encoding for input features
labels = [l for _, l in dataset]
train_idx, val_idx = split_fold10(labels)
# create dataloader
train_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(train_idx),
batch_size=128, pin_memory=torch.cuda.is_available())
val_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(val_idx),
batch_size=128, pin_memory=torch.cuda.is_available())
# create GIN model
in_size = dataset.dim_nfeats
out_size = dataset.gclasses
model = GIN(in_size, 16, out_size).to(device)
# model training/validating
print('Training...')
train(train_loader, val_loader, device, model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment