Unverified Commit 39764da4 authored by Kay Liu, committed by GitHub

[model] add model example GeniePath (#3199)



* [model] add model example GeniePath

* improvements based on feedback

* improvements based on feedback
Co-authored-by: zhjwy9343 <6593865@qq.com>
parent a5dc230a
@@ -165,6 +165,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name='gas'></a> Li A, Qin Z, et al. Spam Review Detection with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1908.10679).
- Example code: [PyTorch](../examples/pytorch/gas)
- Tags: Fraud detection, Heterogeneous graph, Edge classification, Graph attention
- <a name='geniepath'></a> Liu Z, et al. Geniepath: Graph neural networks with adaptive receptive paths. [Paper link](https://arxiv.org/abs/1802.00910).
- Example code: [PyTorch](../examples/pytorch/geniepath)
- Tags: Fraud detection, Node classification, Graph attention, LSTM, Adaptive receptive fields
## 2018
# DGL Implementation of the GeniePath Paper

This DGL example implements the GNN model proposed in the paper [GeniePath: Graph Neural Networks with Adaptive Receptive Paths](https://arxiv.org/abs/1802.00910). GeniePath stacks adaptive path layers, each of which combines a breadth function (graph attention over one-hop neighbors) with a depth function (an LSTM over the outputs of successive layers) to learn adaptive receptive fields.

Example implementor
----------------------
This example was implemented by [Kay Liu](https://github.com/kayzliu) during his SDE intern work at the AWS Shanghai AI Lab.

Dependencies
----------------------
- Python 3.7.10
- PyTorch 1.8.1
- dgl 0.7.0
- scikit-learn 0.23.2

Dataset
---------------------------------------
The datasets used for node classification are the [Pubmed citation network dataset](https://docs.dgl.ai/api/python/dgl.data.html#dgl.data.PubmedGraphDataset) (transductive) and the [Protein-Protein Interaction dataset](https://docs.dgl.ai/api/python/dgl.data.html#dgl.data.PPIDataset) (inductive).
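
Both datasets are provided by `dgl.data` and are downloaded automatically on first use. A minimal sketch of loading them (assuming dgl 0.7 is installed):

```
# Minimal sketch: load the two datasets used in this example.
from dgl.data import PubmedGraphDataset, PPIDataset

pubmed = PubmedGraphDataset()          # a single citation graph; transductive splits are stored as node masks
graph = pubmed[0]
print(graph.ndata['feat'].shape, int(graph.ndata['train_mask'].sum()))

ppi_train = PPIDataset(mode='train')   # a set of protein-protein interaction graphs for inductive training
print(len(ppi_train), ppi_train[0].ndata['label'].shape)
```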

How to run
--------------------------------
To train on Pubmed (transductive), run

```
python pubmed.py
```

To train on a GPU, run

```
python pubmed.py --gpu 0
```

To train the GeniePath-Lazy variant, run

```
python pubmed.py --lazy
```

To train on PPI (inductive), run

```
python ppi.py
```

Performance
-------------------------
| Dataset | Pubmed (ACC) | PPI (micro-F1) |
| ------- | ------------ | -------------- |
| Paper   | 78.5%        | 0.952          |
| DGL     | 73.0%        | 0.959          |
import torch as th
import torch.nn as nn
from dgl.nn import GATConv
from torch.nn import LSTM


class GeniePathConv(nn.Module):
    """One adaptive path layer: a GAT breadth function followed by an LSTM depth function."""
    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):
        super(GeniePathConv, self).__init__()
        self.breadth_func = GATConv(in_dim, hid_dim, num_heads=num_heads, residual=residual)
        self.depth_func = LSTM(hid_dim, out_dim)

    def forward(self, graph, x, h, c):
        # Breadth: attend over one-hop neighbors, then average the attention heads.
        x = self.breadth_func(graph, x)
        x = th.mean(x, dim=1)
        # Depth: feed the aggregated signal to the LSTM as a length-1 sequence.
        x, (h, c) = self.depth_func(x.unsqueeze(0), (h, c))
        x = x[0]
        return x, (h, c)


class GeniePath(nn.Module):
    def __init__(self, in_dim, out_dim, hid_dim=16, num_layers=2, num_heads=1, residual=False):
        super(GeniePath, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(GeniePathConv(hid_dim, hid_dim, hid_dim, num_heads=num_heads, residual=residual))

    def forward(self, graph, x):
        # Initialize the LSTM hidden and cell states, one state vector per node.
        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
        c = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
        x = self.linear1(x)
        for layer in self.layers:
            x, (h, c) = layer(graph, x, h, c)
        x = self.linear2(x)
        return x


class GeniePathLazy(nn.Module):
    def __init__(self, in_dim, out_dim, hid_dim=16, num_layers=2, num_heads=1, residual=False):
        super(GeniePathLazy, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.breadths = nn.ModuleList()
        self.depths = nn.ModuleList()
        for i in range(num_layers):
            self.breadths.append(GATConv(hid_dim, hid_dim, num_heads=num_heads, residual=residual))
            self.depths.append(LSTM(hid_dim * 2, hid_dim))

    def forward(self, graph, x):
        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
        c = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
        x = self.linear1(x)
        # Lazy variant: run every breadth function on the same input first...
        h_tmps = []
        for layer in self.breadths:
            h_tmps.append(th.mean(layer(graph, x), dim=1))
        x = x.unsqueeze(0)
        # ...then propagate through the depth functions, concatenating each breadth output.
        for h_tmp, layer in zip(h_tmps, self.depths):
            in_cat = th.cat((h_tmp.unsqueeze(0), x), -1)
            x, (h, c) = layer(in_cat, (h, c))
        x = self.linear2(x[0])
        return x
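

# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the original example): run a forward
# pass of GeniePath on a small random graph. The graph size and feature
# dimensions below are arbitrary and chosen only for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import dgl

    g = dgl.rand_graph(10, 30)          # 10 nodes, 30 random edges
    g = dgl.add_self_loop(g)            # GATConv requires every node to have in-degree > 0
    feat = th.randn(10, 8)              # 8-dimensional node features

    toy_model = GeniePath(in_dim=8, out_dim=3, hid_dim=16, num_layers=2)
    out = toy_model(g, feat)
    print(out.shape)                    # torch.Size([10, 3]): one score per class per node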
import argparse
import numpy as np
import torch as th
import torch.optim as optim
from dgl.data import PPIDataset
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import f1_score
from model import GeniePath, GeniePathLazy


def evaluate(model, loss_fn, dataloader, device='cpu'):
    loss = 0
    f1 = 0
    num_blocks = 0
    # gradients are not needed for evaluation
    with th.no_grad():
        for subgraph in dataloader:
            subgraph = subgraph.to(device)
            label = subgraph.ndata['label']
            feat = subgraph.ndata['feat']
            logits = model(subgraph, feat)
            # accumulate loss and micro-F1 over the batched graphs
            loss += loss_fn(logits, label).item()
            # a logit >= 0 corresponds to a predicted probability >= 0.5
            predict = np.where(logits.data.cpu().numpy() >= 0., 1, 0)
            f1 += f1_score(label.cpu(), predict, average='micro')
            num_blocks += 1
    return f1 / num_blocks, loss / num_blocks


def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
train_dataset = PPIDataset(mode='train')
valid_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')
train_dataloader = GraphDataLoader(train_dataset, batch_size=args.batch_size)
valid_dataloader = GraphDataLoader(valid_dataset, batch_size=args.batch_size)
test_dataloader = GraphDataLoader(test_dataset, batch_size=args.batch_size)
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
num_classes = train_dataset.num_labels
# Extract node features
graph = train_dataset[0]
feat = graph.ndata['feat']
# Step 2: Create model =================================================================== #
if args.lazy:
model = GeniePathLazy(in_dim=feat.shape[-1],
out_dim=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
num_heads=args.num_heads,
residual=args.residual)
else:
model = GeniePath(in_dim=feat.shape[-1],
out_dim=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
num_heads=args.num_heads,
residual=args.residual)
model = model.to(device)
# Step 3: Create training components ===================================================== #
loss_fn = th.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
model.train()
tr_loss = 0
tr_f1 = 0
num_blocks = 0
for subgraph in train_dataloader:
subgraph = subgraph.to(device)
label = subgraph.ndata['label']
feat = subgraph.ndata['feat']
logits = model(subgraph, feat)
# compute loss
batch_loss = loss_fn(logits, label)
tr_loss += batch_loss.item()
tr_predict = np.where(logits.data.cpu().numpy() >= 0., 1, 0)
tr_f1 += f1_score(label.cpu(), tr_predict, average='micro')
num_blocks += 1
# backward
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
# validation
model.eval()
val_f1, val_loss = evaluate(model, loss_fn, valid_dataloader, device)
print("In epoch {}, Train F1: {:.4f} | Train Loss: {:.4f}; Valid F1: {:.4f} | Valid loss: {:.4f}".
format(epoch, tr_f1 / num_blocks, tr_loss / num_blocks, val_f1, val_loss))
# Test after all epoch
model.eval()
test_f1, test_loss = evaluate(model, loss_fn, test_dataloader, device)
print("Test F1: {:.4f} | Test loss: {:.4f}".
format(test_f1, test_loss))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GeniePath')
    parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
    parser.add_argument("--hid_dim", type=int, default=256, help="Hidden layer dimension")
    parser.add_argument("--num_layers", type=int, default=3, help="Number of GeniePath layers")
    parser.add_argument("--max_epoch", type=int, default=1000, help="The max number of epochs. Default: 1000")
    parser.add_argument("--lr", type=float, default=0.0004, help="Learning rate. Default: 0.0004")
    parser.add_argument("--num_heads", type=int, default=1, help="Number of heads in the breadth function. Default: 1")
    parser.add_argument("--residual", action='store_true', help="Use residual connections in GAT")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size of graph dataloader")
    parser.add_argument("--lazy", action='store_true', help="Use the GeniePath-Lazy variant")
    args = parser.parse_args()
    print(args)
    th.manual_seed(16)
    main(args)
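
# Note on batching in the script above: GraphDataLoader collates `batch_size` PPI
# graphs into one batched DGLGraph per iteration (via dgl.batch), so each training
# step sees a single disconnected graph whose ndata['feat'] and ndata['label'] are
# the per-graph tensors stacked along the node dimension.
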
import argparse
import torch as th
import torch.optim as optim
from dgl.data import PubmedGraphDataset
from sklearn.metrics import accuracy_score
from model import GeniePath, GeniePathLazy


def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load dataset
dataset = PubmedGraphDataset()
graph = dataset[0]
# check cuda
if args.gpu >= 0 and th.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
else:
device = 'cpu'
num_classes = dataset.num_classes
# retrieve label of ground truth
label = graph.ndata['label'].to(device)
# Extract node features
feat = graph.ndata['feat'].to(device)
# retrieve masks for train/validation/test
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
test_idx = th.nonzero(test_mask, as_tuple=False).squeeze(1).to(device)
graph = graph.to(device)
# Step 2: Create model =================================================================== #
if args.lazy:
model = GeniePathLazy(in_dim=feat.shape[-1],
out_dim=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
num_heads=args.num_heads,
residual=args.residual)
else:
model = GeniePath(in_dim=feat.shape[-1],
out_dim=num_classes,
hid_dim=args.hid_dim,
num_layers=args.num_layers,
num_heads=args.num_heads,
residual=args.residual)
model = model.to(device)
# Step 3: Create training components ===================================================== #
loss_fn = th.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Step 4: training epochs =============================================================== #
for epoch in range(args.max_epoch):
# Training and validation
model.train()
logits = model(graph, feat)
# compute loss
tr_loss = loss_fn(logits[train_idx], label[train_idx])
tr_acc = accuracy_score(label[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu())
# validation
valid_loss = loss_fn(logits[val_idx], label[val_idx])
valid_acc = accuracy_score(label[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu())
# backward
optimizer.zero_grad()
tr_loss.backward()
optimizer.step()
# Print out performance
print("In epoch {}, Train ACC: {:.4f} | Train Loss: {:.4f}; Valid ACC: {:.4f} | Valid loss: {:.4f}".
format(epoch, tr_acc, tr_loss.item(), valid_acc, valid_loss.item()))
# Test after all epoch
model.eval()
# forward
logits = model(graph, feat)
# compute loss
test_loss = loss_fn(logits[test_idx], label[test_idx])
test_acc = accuracy_score(label[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu())
print("Test ACC: {:.4f} | Test loss: {:.4f}".
format(test_acc, test_loss.item()))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GeniePath')
    parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
    parser.add_argument("--hid_dim", type=int, default=16, help="Hidden layer dimension")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of GeniePath layers")
    parser.add_argument("--max_epoch", type=int, default=300, help="The max number of epochs. Default: 300")
    parser.add_argument("--lr", type=float, default=0.0004, help="Learning rate. Default: 0.0004")
    parser.add_argument("--num_heads", type=int, default=1, help="Number of heads in the breadth function. Default: 1")
    parser.add_argument("--residual", action='store_true', help="Use residual connections in GAT")
    parser.add_argument("--lazy", action='store_true', help="Use the GeniePath-Lazy variant")
    args = parser.parse_args()
    th.manual_seed(16)
    print(args)
    main(args)
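
# Note on the index extraction above: `th.nonzero(mask, as_tuple=False)` returns an
# (N, 1) tensor of the positions where the boolean mask is True, and `.squeeze(1)`
# flattens it to a 1-D index tensor; e.g. a mask [True, False, True] yields
# tensor([0, 2]), which is used to slice logits and labels for each split.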