Commit 788d8dd4 authored by Hao Zhang's avatar Hao Zhang Committed by Minjie Wang
Browse files

[Model]PPI dataloader and inductive learning script. (#395)

* Create ppi.py

* Create train_ppi.py

* Update train_ppi.py

* Update train_ppi.py

* Create gat.py

* Update train.py

* Update train_ppi.py

* Update ppi.py

* Update train_ppi.py

* Update ppi.py

* Update train_ppi.py

* Update train_ppi.py

* Update ppi.py

* Update train_ppi.py

* update docs and readme
parent 1ea0bcf4
......@@ -32,3 +32,9 @@ Mini graph classification dataset
.. autoclass:: MiniGCDataset
:members: __getitem__, __len__, num_classes
Protein-Protein Interaction dataset
```````````````````````````````````
.. autoclass:: PPIDataset
:members: __getitem__, __len__
......@@ -11,6 +11,7 @@ Dependencies
------------
- torch v1.0: the autograd support for sparse mm is only available in v1.0.
- requests
- sklearn
```bash
pip install torch==1.0.0 requests scikit-learn
......@@ -33,6 +34,10 @@ python train.py --dataset=citeseer --gpu=0
python train.py --dataset=pubmed --gpu=0 --num-out-heads=8 --weight-decay=0.001
```
```bash
python train_ppi.py --gpu=0
```
Results
-------
......
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import torch
import torch.nn as nn
import dgl.function as fn
class GraphAttention(nn.Module):
    """A single multi-head graph attention layer (GAT, arXiv:1710.10903).

    For each of ``num_heads`` heads, computes softmax-normalized attention
    coefficients over the edges of ``g`` and aggregates neighbor features,
    using DGL's SPMV-optimized message passing.
    """
    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        """
        Parameters
        ----------
        g : DGLGraph
            Graph the layer operates on; its ndata/edata are used as scratch space.
        in_dim : int
            Input feature size per node.
        out_dim : int
            Output feature size per node, per head.
        num_heads : int
            Number of attention heads.
        feat_drop : float
            Dropout rate on input features (0 disables dropout).
        attn_drop : float
            Dropout rate on attention values (0 disables dropout).
        alpha : float
            Negative slope of the LeakyReLU applied to attention logits.
        residual : bool
            If True, add a residual connection from the layer input.
        """
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        # single shared projection for all heads: in_dim -> num_heads * out_dim
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        # a zero rate means "no dropout"; use an identity function instead
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x : x
        if attn_drop:
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x : x
        # per-head attention vectors for the source ('l') and destination ('r') ends
        self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.residual = residual
        if residual:
            # project the input when dimensions differ; otherwise reuse it as-is
            if in_dim != out_dim:
                self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
                nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
            else:
                self.res_fc = None

    def forward(self, inputs):
        """Apply the layer: node features (NxD) -> per-head outputs (NxHxD')."""
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        head_ft = ft.transpose(0, 1)  # HxNxD'
        # per-head dot products with the left/right attention vectors
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
        # 4. apply normalizer
        ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
        # 5. residual
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                # dims already match: broadcast the input across heads
                resval = torch.unsqueeze(h, 1)  # Nx1xD'
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
        return {'a' : a}

    def edge_softmax(self):
        """Numerically-stable softmax over incoming edges (results in g's data)."""
        # compute the max per destination node (for numerical stability)
        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
        # minus the max and exp
        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
        # compute dropout
        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
        # compute normalizer
        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
class GAT(nn.Module):
    """Multi-layer Graph Attention Network.

    Stacks an input projection, ``num_layers - 1`` hidden attention layers
    and an output projection. Hidden layers concatenate their heads
    (``flatten(1)``); the output layer averages its heads (``mean(1)``).
    """
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.activation = activation
        self.gat_layers = nn.ModuleList()
        # input projection never uses a residual connection
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False))
        # hidden layers: inputs are the concatenation of the previous layer's heads
        for layer_idx in range(1, num_layers):
            self.gat_layers.append(GraphAttention(
                g, num_hidden * heads[layer_idx - 1], num_hidden, heads[layer_idx],
                feat_drop, attn_drop, alpha, residual))
        # output projection maps onto the class dimension
        self.gat_layers.append(GraphAttention(
            g, num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))

    def forward(self, inputs):
        """Run all layers on node features ``inputs`` and return logits."""
        h = inputs
        # every layer except the last concatenates heads and applies activation
        for layer in self.gat_layers[:-1]:
            h = self.activation(layer(h).flatten(1))
        # output projection: average the heads instead of concatenating
        return self.gat_layers[-1](h).mean(1)
"""
Graph Attention Networks in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code does not implement
early stopping.
References
----------
Paper: https://arxiv.org/abs/1710.10903
......@@ -16,129 +14,10 @@ import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
import dgl.function as fn
class GraphAttention(nn.Module):
    """A single multi-head graph attention layer (GAT, arXiv:1710.10903).

    NOTE: this copy is identical to the ``GraphAttention`` class in gat.py
    shown earlier in this change; the training script now imports from there.
    """
    def __init__(self,
                 g,
                 in_dim,
                 out_dim,
                 num_heads,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual=False):
        """
        Parameters
        ----------
        g : DGLGraph
            Graph the layer operates on; its ndata/edata are used as scratch space.
        in_dim : int
            Input feature size per node.
        out_dim : int
            Output feature size per node, per head.
        num_heads : int
            Number of attention heads.
        feat_drop : float
            Dropout rate on input features (0 disables dropout).
        attn_drop : float
            Dropout rate on attention values (0 disables dropout).
        alpha : float
            Negative slope of the LeakyReLU applied to attention logits.
        residual : bool
            If True, add a residual connection from the layer input.
        """
        super(GraphAttention, self).__init__()
        self.g = g
        self.num_heads = num_heads
        # single shared projection for all heads: in_dim -> num_heads * out_dim
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        # a zero rate means "no dropout"; use an identity function instead
        if feat_drop:
            self.feat_drop = nn.Dropout(feat_drop)
        else:
            self.feat_drop = lambda x : x
        if attn_drop:
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            self.attn_drop = lambda x : x
        # per-head attention vectors for the source ('l') and destination ('r') ends
        self.attn_l = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        self.attn_r = nn.Parameter(torch.Tensor(size=(num_heads, out_dim, 1)))
        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        self.leaky_relu = nn.LeakyReLU(alpha)
        self.residual = residual
        if residual:
            # project the input when dimensions differ; otherwise reuse it as-is
            if in_dim != out_dim:
                self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
                nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
            else:
                self.res_fc = None

    def forward(self, inputs):
        """Apply the layer: node features (NxD) -> per-head outputs (NxHxD')."""
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        head_ft = ft.transpose(0, 1)  # HxNxD'
        # per-head dot products with the left/right attention vectors
        a1 = torch.bmm(head_ft, self.attn_l).transpose(0, 1)  # NxHx1
        a2 = torch.bmm(head_ft, self.attn_r).transpose(0, 1)  # NxHx1
        self.g.ndata.update({'ft' : ft, 'a1' : a1, 'a2' : a2})
        # 1. compute edge attention
        self.g.apply_edges(self.edge_attention)
        # 2. compute softmax in two parts: exp(x - max(x)) and sum(exp(x - max(x)))
        self.edge_softmax()
        # 3. compute the aggregated node features scaled by the dropped,
        # unnormalized attention values.
        self.g.update_all(fn.src_mul_edge('ft', 'a_drop', 'ft'), fn.sum('ft', 'ft'))
        # 4. apply normalizer
        ret = self.g.ndata['ft'] / self.g.ndata['z']  # NxHxD'
        # 5. residual
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
            else:
                # dims already match: broadcast the input across heads
                resval = torch.unsqueeze(h, 1)  # Nx1xD'
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = self.leaky_relu(edges.src['a1'] + edges.dst['a2'])
        return {'a' : a}

    def edge_softmax(self):
        """Numerically-stable softmax over incoming edges (results in g's data)."""
        # compute the max per destination node (for numerical stability)
        self.g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
        # minus the max and exp
        self.g.apply_edges(lambda edges : {'a' : torch.exp(edges.data['a'] - edges.dst['a_max'])})
        # compute dropout
        self.g.apply_edges(lambda edges : {'a_drop' : self.attn_drop(edges.data['a'])})
        # compute normalizer
        self.g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
class GAT(nn.Module):
    """Multi-layer Graph Attention Network.

    NOTE: identical to the ``GAT`` class in gat.py shown earlier in this
    change; the training script now imports from there.
    """
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 alpha,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GraphAttention(
            g, in_dim, num_hidden, heads[0], feat_drop, attn_drop, alpha, False))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GraphAttention(
                g, num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, alpha, residual))
        # output projection
        self.gat_layers.append(GraphAttention(
            g, num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, alpha, residual))

    def forward(self, inputs):
        """Run all layers on node features ``inputs`` and return logits."""
        h = inputs
        # hidden layers concatenate their heads and apply the activation
        for l in range(self.num_layers):
            h = self.gat_layers[l](h).flatten(1)
            h = self.activation(h)
        # output projection averages its heads
        logits = self.gat_layers[-1](h).mean(1)
        return logits
from gat import GAT
def accuracy(logits, labels):
_, indices = torch.max(logits, dim=1)
......
"""
Graph Attention Networks (PPI Dataset) in DGL using SPMV optimization.
Multiple heads are also batched together for faster training.
Compared with the original paper, this code implements
early stopping.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import numpy as np
import torch
import dgl
import torch.nn.functional as F
import argparse
from sklearn.metrics import f1_score
from gat import GAT
from dgl.data.ppi import PPIDataset
from torch.utils.data import DataLoader
def collate(sample):
    """Collate a list of (graph, feats, labels) samples into one batch.

    The graphs are merged with ``dgl.batch`` and the per-graph feature and
    label arrays are concatenated along the node axis and converted to
    torch tensors.
    """
    graph_list, feat_list, label_list = zip(*sample)
    batched_graph = dgl.batch(list(graph_list))
    batched_feats = torch.from_numpy(np.concatenate(feat_list))
    batched_labels = torch.from_numpy(np.concatenate(label_list))
    return batched_graph, batched_feats, batched_labels
def evaluate(feats, model, subgraph, labels, loss_fcn):
    """Evaluate ``model`` on one batched subgraph.

    Returns
    -------
    (float, float)
        Micro-averaged F1 score and the loss value.
    """
    model.eval()
    with torch.no_grad():
        # point the model and every attention layer at the evaluation graph
        model.g = subgraph
        for layer in model.gat_layers:
            layer.g = subgraph
        output = model(feats.float())
        loss_data = loss_fcn(output, labels.float())
        # threshold the logits at 0.5 to obtain hard multi-label predictions
        predict = np.where(output.data.cpu().numpy() >= 0.5, 1, 0)
        score = f1_score(labels.data.cpu().numpy(),
                         predict, average='micro')
    return score, loss_data.item()
def main(args):
    """Train GAT on the PPI dataset with early stopping, then report test F1.

    Parameters
    ----------
    args : argparse.Namespace
        Command-line options (see the argument parser at the bottom of
        this file).
    """
    if args.gpu<0:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.gpu))
    batch_size = args.batch_size
    # early-stopping bookkeeping
    cur_step = 0
    patience = args.patience
    best_score = -1
    best_loss = 10000
    # define loss function (multi-label task: independent sigmoid/BCE per class)
    loss_fcn = torch.nn.BCEWithLogitsLoss()
    # create the dataset
    train_dataset = PPIDataset(mode='train')
    valid_dataset = PPIDataset(mode='valid')
    test_dataset = PPIDataset(mode='test')
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
    n_classes = train_dataset.labels.shape[1]
    num_feats = train_dataset.features.shape[1]
    g = train_dataset.graph
    # one head count per hidden layer plus the output layer's head count
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    # define the model
    model = GAT(g,
                args.num_layers,
                num_feats,
                args.num_hidden,
                n_classes,
                heads,
                F.elu,
                args.in_drop,
                args.attn_drop,
                args.alpha,
                args.residual)
    # define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    model = model.to(device)
    for epoch in range(args.epochs):
        model.train()
        loss_list = []
        for batch, data in enumerate(train_dataloader):
            subgraph, feats, labels = data
            feats = feats.to(device)
            labels = labels.to(device)
            # point the model and every attention layer at the current batched graph
            model.g = subgraph
            for layer in model.gat_layers:
                layer.g = subgraph
            logits = model(feats.float())
            loss = loss_fcn(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        loss_data = np.array(loss_list).mean()
        print("Epoch {:05d} | Loss: {:.4f}".format(epoch + 1, loss_data))
        # validate every 5 epochs
        if epoch % 5 == 0:
            score_list = []
            val_loss_list = []
            for batch, valid_data in enumerate(valid_dataloader):
                subgraph, feats, labels = valid_data
                feats = feats.to(device)
                labels = labels.to(device)
                score, val_loss = evaluate(feats.float(), model, subgraph, labels.float(), loss_fcn)
                score_list.append(score)
                val_loss_list.append(val_loss)
            mean_score = np.array(score_list).mean()
            mean_val_loss = np.array(val_loss_list).mean()
            print("F1-Score: {:.4f} ".format(mean_score))
            # early stop: reset patience when either F1 or validation loss improves
            if mean_score > best_score or best_loss > mean_val_loss:
                if mean_score > best_score and best_loss > mean_val_loss:
                    # both improved simultaneously: remember this checkpoint's stats
                    val_early_loss = mean_val_loss
                    val_early_score = mean_score
                best_score = np.max((mean_score, best_score))
                best_loss = np.min((best_loss, mean_val_loss))
                cur_step = 0
            else:
                cur_step += 1
                if cur_step == patience:
                    break
    # final evaluation on the test split
    test_score_list = []
    for batch, test_data in enumerate(test_dataloader):
        subgraph, feats, labels = test_data
        feats = feats.to(device)
        labels = labels.to(device)
        test_score_list.append(evaluate(feats, model, subgraph, labels.float(), loss_fcn)[0])
    print("F1-Score: {:.4f}".format(np.array(test_score_list).mean()))
if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters and run training.
    parser = argparse.ArgumentParser(description='GAT')
    parser.add_argument("--gpu", type=int, default=-1,
                        help="which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=400,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=4,
                        help="number of hidden attention heads")
    parser.add_argument("--num-out-heads", type=int, default=6,
                        help="number of output attention heads")
    parser.add_argument("--num-layers", type=int, default=2,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=256,
                        help="number of hidden units")
    # NOTE(review): with action="store_true" AND default=True this flag is a
    # no-op (residual is always on, and cannot be disabled from the command
    # line). Kept as-is to preserve behavior; consider default=False or a
    # --no-residual flag.
    parser.add_argument("--residual", action="store_true", default=True,
                        help="use residual connection")
    parser.add_argument("--in-drop", type=float, default=0,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=0,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=0,
                        help="weight decay")
    parser.add_argument('--alpha', type=float, default=0.2,
                        help="the negative slope of leaky relu")
    parser.add_argument('--batch-size', type=int, default=2,
                        help="batch size used for training, validation and test")
    parser.add_argument('--patience', type=int, default=10,
                        help="used for early stop")
    args = parser.parse_args()
    print(args)
    main(args)
......@@ -8,6 +8,7 @@ from .tree import *
from .utils import *
from .sbm import SBMMixture
from .reddit import RedditDataset
from .ppi import PPIDataset
def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=False,
......
"""PPI Dataset.
(zhang hao): Used for inductive learning.
"""
import json
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
_url = 'dataset/ppi.zip'
class PPIDataset(object):
    """A toy Protein-Protein Interaction network dataset.

    Adapted from https://github.com/williamleif/GraphSAGE/tree/master/example_data.

    The dataset contains 24 graphs. The average number of nodes per graph
    is 2372. Each node has 50 features and 121 labels.

    We use 20 graphs for training, 2 for validation and 2 for testing.
    """

    # Graph ids (values of *_graph_id.npy) belonging to each split.
    _GRAPH_ID_RANGES = {
        'train': range(1, 21),
        'valid': range(21, 23),
        'test': range(23, 25),
    }

    def __init__(self, mode):
        """Initialize the dataset.

        Parameters
        ----------
        mode : str
            One of 'train', 'valid' or 'test'.

        Raises
        ------
        ValueError
            If ``mode`` is not a known split name.
        """
        if mode not in self._GRAPH_ID_RANGES:
            raise ValueError(
                "mode must be one of 'train', 'valid' or 'test', got %r" % (mode,))
        self.mode = mode
        self._load()
        self._preprocess()

    def _load(self):
        """Download (if needed) and load the raw data for ``self.mode``.

        Files per split (<mode> in {train, valid, test}):
        - <mode>_graph.json: the graph in node-link JSON format;
        - <mode>_feats.npy: node feature matrix, a numpy.ndarray of shape
          [n, v] where n is the number of nodes and v the feature dimension;
        - <mode>_labels.npy: multi-hot label matrix, a numpy.ndarray of
          shape [n, h] where h is the label dimension, e.g.
          [[0, 0, 1, ..., 0], [0, 1, 1, ..., 1]];
        - <mode>_graph_id.npy: 1-D numpy.ndarray of length n mapping each
          node to the id of the graph it belongs to, e.g. [1, 1, 2, ..., 20].
        """
        name = 'ppi'
        download_dir = get_download_dir()
        zip_file_path = '{}/{}.zip'.format(download_dir, name)
        download(_get_dgl_url(_url), path=zip_file_path)
        extract_archive(zip_file_path,
                        '{}/{}'.format(download_dir, name))
        print('Loading G...')
        # All three splits share the same layout; only the file prefix differs.
        prefix = '{}/ppi/{}'.format(download_dir, self.mode)
        with open('{}_graph.json'.format(prefix)) as jsonfile:
            g_data = json.load(jsonfile)
        self.labels = np.load('{}_labels.npy'.format(prefix))
        self.features = np.load('{}_feats.npy'.format(prefix))
        self.graph = DGLGraph(nx.DiGraph(json_graph.node_link_graph(g_data)))
        self.graph_id = np.load('{}_graph_id.npy'.format(prefix))

    def _preprocess(self):
        """Split the loaded disjoint-union graph into per-sample subgraphs."""
        mask_list = []
        graphs = []
        labels = []
        for graph_id in self._GRAPH_ID_RANGES[self.mode]:
            # node indices of the nodes belonging to this graph id
            mask = np.where(self.graph_id == graph_id)[0]
            mask_list.append(mask)
            graphs.append(self.graph.subgraph(mask))
            labels.append(self.labels[mask])
        # Keep the original, mode-prefixed attribute names
        # (train_mask_list, valid_graphs, ...) for backward compatibility.
        setattr(self, self.mode + '_mask_list', mask_list)
        setattr(self, self.mode + '_graphs', graphs)
        setattr(self, self.mode + '_labels', labels)

    def __len__(self):
        """Return the number of samples (subgraphs) in this dataset."""
        return len(getattr(self, self.mode + '_mask_list'))

    def __getitem__(self, item):
        """Get the item^th sample.

        Parameters
        ----------
        item : int
            The sample index.

        Returns
        -------
        (dgl.DGLGraph, ndarray, ndarray)
            The subgraph, its node features and its node labels.
        """
        graphs = getattr(self, self.mode + '_graphs')
        masks = getattr(self, self.mode + '_mask_list')
        labels = getattr(self, self.mode + '_labels')
        return graphs[item], self.features[masks[item]], labels[item]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment