Unverified commit 6e1be69a authored by Chang Liu, committed by GitHub

[Example][Refactor] Refactor GAT example (#4240)



* Refactor gat example

* Add ppi support

* Minor update

* Update

* Update

* Change valid_xxx to val_xxx

* Readme Update

* Update
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 05d9d496
README.md
---------

Graph Attention Networks (GAT)
============
- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo (TensorFlow implementation):
[https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular PyTorch implementation:
[https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).

How to run
----------
Run the following for multiclass node classification (available datasets: "cora", "citeseer", "pubmed"):
```bash
python3 train.py --dataset cora
```

Run the following for multilabel classification on the PPI dataset:
```bash
python3 train_ppi.py
```

> **_NOTE:_** Users may occasionally run into a low-accuracy issue (e.g., test accuracy < 0.8) due to overfitting. It can be resolved by adding early stopping or by reducing the maximum number of training epochs.

Summary
-------
* cora: ~0.821
* citeseer: ~0.710
* pubmed: ~0.780
* ppi: ~0.9744
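The two entry points differ mainly in their prediction target: the citation datasets are multiclass (each node belongs to exactly one class, so train.py uses softmax cross-entropy), while PPI is multilabel (a node may carry several of its 121 labels at once, so train_ppi.py uses a per-label sigmoid via `BCEWithLogitsLoss`). A minimal sketch of that distinction, with made-up shapes:

```python
import torch
import torch.nn as nn

# multiclass (train.py): one class index per node, softmax cross-entropy
logits = torch.randn(4, 7)                       # 4 nodes, 7 classes
targets = torch.tensor([0, 3, 6, 2])             # one class id per node
multiclass_loss = nn.CrossEntropyLoss()(logits, targets)

# multilabel (train_ppi.py): independent 0/1 target per label, per node
logits = torch.randn(4, 121)                     # 4 nodes, 121 PPI labels
targets = torch.randint(0, 2, (4, 121)).float()  # several labels may be 1
multilabel_loss = nn.BCEWithLogitsLoss()(logits, targets)
```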
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn import GATConv
class GAT(nn.Module):
def __init__(self,
g,
num_layers,
in_dim,
num_hidden,
num_classes,
heads,
activation,
feat_drop,
attn_drop,
negative_slope,
residual):
super(GAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.gat_layers = nn.ModuleList()
self.activation = activation
if num_layers > 1:
# input projection (no residual)
self.gat_layers.append(GATConv(
in_dim, num_hidden, heads[0],
feat_drop, attn_drop, negative_slope, False, self.activation))
# hidden layers
for l in range(1, num_layers-1):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gat_layers.append(GATConv(
num_hidden * heads[l-1], num_hidden, heads[l],
feat_drop, attn_drop, negative_slope, residual, self.activation))
# output projection
self.gat_layers.append(GATConv(
num_hidden * heads[-2], num_classes, heads[-1],
feat_drop, attn_drop, negative_slope, residual, None))
else:
self.gat_layers.append(GATConv(
in_dim, num_classes, heads[0],
feat_drop, attn_drop, negative_slope, residual, None))
def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](self.g, h)
h = h.flatten(1) if l != self.num_layers - 1 else h.mean(1)
return h
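The head bookkeeping above is the subtle part of this (pre-refactor) module: multi-head outputs are concatenated between layers, so a hidden layer consumes `num_hidden * num_heads` input features, and only the output layer averages its heads. A minimal usage sketch on a made-up toy graph (all sizes illustrative):

```python
import dgl
import torch
import torch.nn.functional as F

# toy graph: 4 nodes on a directed cycle, plus self-loops for attention
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])))
feats = torch.randn(4, 10)  # 4 nodes, 10 input features

# two layers: 8 heads of 8 hidden units each, then 1 averaged output head
model = GAT(g, num_layers=2, in_dim=10, num_hidden=8, num_classes=3,
            heads=[8, 1], activation=F.elu, feat_drop=0.6, attn_drop=0.6,
            negative_slope=0.2, residual=False)
logits = model(feats)  # layer 0 flattens to (4, 8*8); output averages to (4, 3)
```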
"""
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import argparse
import numpy as np
import networkx as nx
import time
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl.nn as dglnn
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl import AddSelfLoop
import argparse
from gat import GAT class GAT(nn.Module):
from utils import EarlyStopping def __init__(self,in_size, hid_size, out_size, heads):
super().__init__()
self.gat_layers = nn.ModuleList()
def accuracy(logits, labels): # two-layer GAT
self.gat_layers.append(dglnn.GATConv(in_size, hid_size, heads[0], feat_drop=0.6, attn_drop=0.6, activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*heads[0], out_size, heads[1], feat_drop=0.6, attn_drop=0.6, activation=None))
def forward(self, g, inputs):
h = inputs
for i, layer in enumerate(self.gat_layers):
h = layer(g, h)
if i == 1: # last layer
h = h.mean(1)
else: # other layer(s)
h = h.flatten(1)
return h
def evaluate(g, features, labels, mask, model):
model.eval()
with torch.no_grad():
logits = model(g, features)
logits = logits[mask]
labels = labels[mask]
_, indices = torch.max(logits, dim=1) _, indices = torch.max(logits, dim=1)
correct = torch.sum(indices == labels) correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels) return correct.item() * 1.0 / len(labels)
def train(g, features, labels, masks, model):
# define train/val samples, loss function and optimizer
train_mask = masks[0]
val_mask = masks[1]
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
def evaluate(model, features, labels, mask): #training loop
model.eval() for epoch in range(200):
with torch.no_grad(): model.train()
logits = model(features) logits = model(g, features)
logits = logits[mask] loss = loss_fcn(logits[train_mask], labels[train_mask])
labels = labels[mask] optimizer.zero_grad()
return accuracy(logits, labels) loss.backward()
optimizer.step()
acc = evaluate(g, features, labels, val_mask, model)
print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
. format(epoch, loss.item(), acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed').")
args = parser.parse_args()
print(f'Training with DGL built-in GATConv module.')
def main(args):
# load and preprocess dataset # load and preprocess dataset
transform = AddSelfLoop() # by default, it will first remove self-loops to prevent duplication
if args.dataset == 'cora': if args.dataset == 'cora':
data = CoraGraphDataset() data = CoraGraphDataset(transform=transform)
elif args.dataset == 'citeseer': elif args.dataset == 'citeseer':
data = CiteseerGraphDataset() data = CiteseerGraphDataset(transform=transform)
elif args.dataset == 'pubmed': elif args.dataset == 'pubmed':
data = PubmedGraphDataset() data = PubmedGraphDataset(transform=transform)
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError('Unknown dataset: {}'.format(args.dataset))
g = data[0] g = data[0]
if args.gpu < 0: device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cuda = False g = g.int().to(device)
else:
cuda = True
g = g.int().to(args.gpu)
features = g.ndata['feat'] features = g.ndata['feat']
labels = g.ndata['label'] labels = g.ndata['label']
train_mask = g.ndata['train_mask'] masks = g.ndata['train_mask'], g.ndata['val_mask'], g.ndata['test_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item()))
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()
# create model
heads = ([args.num_heads] * (args.num_layers-1)) + [args.num_out_heads]
model = GAT(g,
args.num_layers,
num_feats,
args.num_hidden,
n_classes,
heads,
F.elu,
args.in_drop,
args.attn_drop,
args.negative_slope,
args.residual)
print(model)
if args.early_stop:
stopper = EarlyStopping(patience=100)
if cuda:
model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# initialize graph
dur = []
for epoch in range(args.epochs):
model.train()
if epoch >= 3:
if cuda:
torch.cuda.synchronize()
t0 = time.time()
# forward
logits = model(features)
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
if cuda:
torch.cuda.synchronize()
dur.append(time.time() - t0)
train_acc = accuracy(logits[train_mask], labels[train_mask])
if args.fastmode:
val_acc = accuracy(logits[val_mask], labels[val_mask])
else:
val_acc = evaluate(model, features, labels, val_mask)
if args.early_stop:
if stopper.step(val_acc, model):
break
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" # create GAT model
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}". in_size = features.shape[1]
format(epoch, np.mean(dur), loss.item(), train_acc, out_size = data.num_classes
val_acc, n_edges / np.mean(dur) / 1000)) model = GAT(in_size, 8, out_size, heads=[8,1]).to(device)
print() # model training
if args.early_stop: print('Training...')
model.load_state_dict(torch.load('es_checkpoint.pt')) train(g, features, labels, masks, model)
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=8,
help="number of hidden attention heads")
parser.add_argument("--num-out-heads", type=int, default=1,
help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=2,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
parser.add_argument('--weight-decay', type=float, default=5e-4,
help="weight decay")
parser.add_argument('--negative-slope', type=float, default=0.2,
help="the negative slope of leaky relu")
parser.add_argument('--early-stop', action='store_true', default=False,
help="indicates whether to use early stop or not")
parser.add_argument('--fastmode', action="store_true", default=False,
help="skip re-evaluate the validation set")
args = parser.parse_args()
print(args)
main(args) # test the model
print('Testing...')
acc = evaluate(g, features, labels, masks[2], model)
print("Test accuracy {:.4f}".format(acc))
"""
Graph Attention Networks (PPI Dataset) in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code implements
early stopping.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import numpy as np import numpy as np
import torch import torch
import dgl import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import argparse import dgl.nn as dglnn
from sklearn.metrics import f1_score
from gat import GAT
from dgl.data.ppi import PPIDataset from dgl.data.ppi import PPIDataset
from dgl.dataloading import GraphDataLoader from dgl.dataloading import GraphDataLoader
from sklearn.metrics import f1_score
def evaluate(feats, model, subgraph, labels, loss_fcn): class GAT(nn.Module):
with torch.no_grad(): def __init__(self, in_size, hid_size, out_size, heads):
super().__init__()
self.gat_layers = nn.ModuleList()
# three-layer GAT
self.gat_layers.append(dglnn.GATConv(in_size, hid_size, heads[0], activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*heads[0], hid_size, heads[1], residual=True, activation=F.elu))
self.gat_layers.append(dglnn.GATConv(hid_size*heads[1], out_size, heads[2], residual=True, activation=None))
def forward(self, g, inputs):
h = inputs
for i, layer in enumerate(self.gat_layers):
h = layer(g, h)
if i == 2: # last layer
h = h.mean(1)
else: # other layer(s)
h = h.flatten(1)
return h
def evaluate(g, features, labels, model):
model.eval() model.eval()
model.g = subgraph with torch.no_grad():
for layer in model.gat_layers: output = model(g, features)
layer.g = subgraph pred = np.where(output.data.cpu().numpy() >= 0, 1, 0)
output = model(feats.float()) score = f1_score(labels.data.cpu().numpy(), pred, average='micro')
loss_data = loss_fcn(output, labels.float()) return score
predict = np.where(output.data.cpu().numpy() >= 0., 1, 0)
score = f1_score(labels.data.cpu().numpy(),
predict, average='micro')
return score, loss_data.item()
def main(args): def evaluate_in_batches(dataloader, device, model):
if args.gpu<0: total_score = 0
device = torch.device("cpu") for batch_id, batched_graph in enumerate(dataloader):
else: batched_graph = batched_graph.to(device)
device = torch.device("cuda:" + str(args.gpu)) features = batched_graph.ndata['feat']
labels = batched_graph.ndata['label']
score = evaluate(batched_graph, features, labels, model)
total_score += score
return total_score / (batch_id + 1) # return average score
batch_size = args.batch_size def train(train_dataloader, val_dataloader, device, model):
cur_step = 0 # define loss function and optimizer
patience = args.patience loss_fcn = nn.BCEWithLogitsLoss()
best_score = -1 optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=0)
best_loss = 10000
# define loss function # training loop
loss_fcn = torch.nn.BCEWithLogitsLoss() for epoch in range(400):
# create the dataset
train_dataset = PPIDataset(mode='train')
valid_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')
train_dataloader = GraphDataLoader(train_dataset, batch_size=batch_size)
valid_dataloader = GraphDataLoader(valid_dataset, batch_size=batch_size)
test_dataloader = GraphDataLoader(test_dataset, batch_size=batch_size)
g = train_dataset[0]
n_classes = train_dataset.num_labels
num_feats = g.ndata['feat'].shape[1]
g = g.int().to(device)
heads = ([args.num_heads] * (args.num_layers-1)) + [args.num_out_heads]
# define the model
model = GAT(g,
args.num_layers,
num_feats,
args.num_hidden,
n_classes,
heads,
F.elu,
args.in_drop,
args.attn_drop,
args.alpha,
args.residual)
# define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
model = model.to(device)
for epoch in range(args.epochs):
model.train() model.train()
loss_list = [] logits = []
for batch, subgraph in enumerate(train_dataloader): total_loss = 0
subgraph = subgraph.to(device) # mini-batch loop
model.g = subgraph for batch_id, batched_graph in enumerate(train_dataloader):
for layer in model.gat_layers: batched_graph = batched_graph.to(device)
layer.g = subgraph features = batched_graph.ndata['feat'].float()
logits = model(subgraph.ndata['feat'].float()) labels = batched_graph.ndata['label'].float()
loss = loss_fcn(logits, subgraph.ndata['label']) logits = model(batched_graph, features)
loss = loss_fcn(logits, labels)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
loss_list.append(loss.item()) total_loss += loss.item()
loss_data = np.array(loss_list).mean() print("Epoch {:05d} | Loss {:.4f} |". format(epoch, total_loss / (batch_id + 1) ))
print("Epoch {:05d} | Loss: {:.4f}".format(epoch + 1, loss_data))
if epoch % 5 == 0: if (epoch + 1) % 5 == 0:
score_list = [] avg_score = evaluate_in_batches(val_dataloader, device, model) # evaluate F1-score instead of loss
val_loss_list = [] print(" Acc. (F1-score) {:.4f} ". format(avg_score))
for batch, subgraph in enumerate(valid_dataloader):
subgraph = subgraph.to(device)
score, val_loss = evaluate(subgraph.ndata['feat'], model, subgraph, subgraph.ndata['label'], loss_fcn)
score_list.append(score)
val_loss_list.append(val_loss)
mean_score = np.array(score_list).mean()
mean_val_loss = np.array(val_loss_list).mean()
print("Val F1-Score: {:.4f} ".format(mean_score))
# early stop
if mean_score > best_score or best_loss > mean_val_loss:
if mean_score > best_score and best_loss > mean_val_loss:
val_early_loss = mean_val_loss
val_early_score = mean_score
best_score = np.max((mean_score, best_score))
best_loss = np.min((best_loss, mean_val_loss))
cur_step = 0
else:
cur_step += 1
if cur_step == patience:
break
test_score_list = []
for batch, subgraph in enumerate(test_dataloader):
subgraph = subgraph.to(device)
score, test_loss = evaluate(subgraph.ndata['feat'], model, subgraph, subgraph.ndata['label'], loss_fcn)
test_score_list.append(score)
print("Test F1-Score: {:.4f}".format(np.array(test_score_list).mean()))
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT') print(f'Training PPI Dataset with DGL built-in GATConv module.')
parser.add_argument("--gpu", type=int, default=-1, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
help="which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=400, # load and preprocess datasets
help="number of training epochs") train_dataset = PPIDataset(mode='train')
parser.add_argument("--num-heads", type=int, default=4, val_dataset = PPIDataset(mode='valid')
help="number of hidden attention heads") test_dataset = PPIDataset(mode='test')
parser.add_argument("--num-out-heads", type=int, default=6, features = train_dataset[0].ndata['feat']
help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=3, # create GAT model
help="number of hidden layers") in_size = features.shape[1]
parser.add_argument("--num-hidden", type=int, default=256, out_size = train_dataset.num_labels
help="number of hidden units") model = GAT(in_size, 256, out_size, heads=[4,4,6]).to(device)
parser.add_argument("--residual", action="store_true", default=True,
help="use residual connection") # model training
parser.add_argument("--in-drop", type=float, default=0, print('Training...')
help="input feature dropout") train_dataloader = GraphDataLoader(train_dataset, batch_size=2)
parser.add_argument("--attn-drop", type=float, default=0, val_dataloader = GraphDataLoader(val_dataset, batch_size=2)
help="attention dropout") train(train_dataloader, val_dataloader, device, model)
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
parser.add_argument('--weight-decay', type=float, default=0,
help="weight decay")
parser.add_argument('--alpha', type=float, default=0.2,
help="the negative slop of leaky relu")
parser.add_argument('--batch-size', type=int, default=2,
help="batch size used for training, validation and test")
parser.add_argument('--patience', type=int, default=10,
help="used for early stop")
args = parser.parse_args()
print(args)
main(args) # test the model
print('Testing...')
test_dataloader = GraphDataLoader(test_dataset, batch_size=2)
avg_score = evaluate_in_batches(test_dataloader, device, model)
print("Test Accuracy (F1-score) {:.4f}".format(avg_score))
utils.py (removed by this commit)
--------

import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
def step(self, acc, model):
score = acc
if self.best_score is None:
self.best_score = score
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(model)
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
        '''Saves the model whenever the validation score improves.'''
torch.save(model.state_dict(), 'es_checkpoint.pt')
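For reference, the pre-refactor train.py wired this helper into its training loop roughly as follows (a condensed sketch, not the full script; `evaluate`, `features`, and the masks are that script's objects):

```python
stopper = EarlyStopping(patience=100)
for epoch in range(epochs):
    # ... forward pass, loss, backward, optimizer step ...
    val_acc = evaluate(model, features, labels, val_mask)
    if stopper.step(val_acc, model):  # checkpoints to 'es_checkpoint.pt' on improvement
        break
# reload the best weights before measuring test accuracy
model.load_state_dict(torch.load('es_checkpoint.pt'))
```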