Unverified commit f908f35c authored by RecLusIve-F, committed by GitHub

[Example]Add P-GNN example (#3823)



* [Model]P-GNN

* update

* [Example]P-GNN

* Update README.md
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent eec219ab
@@ -181,6 +181,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
- <a name='geniepath'></a> Liu Z, et al. Geniepath: Graph neural networks with adaptive receptive paths. [Paper link](https://arxiv.org/abs/1802.00910).
- Example code: [PyTorch](../examples/pytorch/geniepath)
- Tags: Fraud detection, Node classification, Graph attention, LSTM, Adaptive receptive fields
- <a name='pgnn'></a> You J, et al. Position-aware graph neural networks. [Paper link](https://arxiv.org/abs/1906.04817).
- Example code: [PyTorch](../examples/pytorch/P-GNN)
- Tags: Positional encoding, Link prediction, Link-pair prediction
## 2018
examples/pytorch/P-GNN/README.md
# DGL Implementation of P-GNN
This DGL example implements the GNN model proposed in the paper [Position-aware Graph Neural Networks](http://proceedings.mlr.press/v97/you19b/you19b.pdf). For the original implementation, see [here](https://github.com/JiaxuanYou/P-GNN).
Contributor: [RecLusIve-F](https://github.com/RecLusIve-F)
## Requirements
The codebase is implemented in Python 3.8. For the version requirements of the packages, see below.
```
dgl 0.7.2
numpy 1.21.2
torch 1.10.1
networkx 2.6.3
scikit-learn 1.0.2
```
## Instructions to download datasets
1. Download the datasets from [here](https://github.com/RecLusIve-F/P-GNN-dgl/blob/master/data.zip?raw=true)
2. Extract the zip archive in this directory (or use the Python sketch below)
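If you prefer to script the download, here is a minimal sketch using only the Python standard library; it assumes the `data.zip` URL from step 1 and extracts into the current directory:
```python
import io
import zipfile
import urllib.request

# URL from step 1 above (GitHub serves the raw archive behind a redirect,
# which urlopen follows automatically)
URL = 'https://github.com/RecLusIve-F/P-GNN-dgl/blob/master/data.zip?raw=true'

with urllib.request.urlopen(URL) as response:
    archive = zipfile.ZipFile(io.BytesIO(response.read()))
    archive.extractall('.')  # creates the data folder in this directory
```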
## Instructions for experiments
### Link prediction
```bash
# Communities-T
python main.py --task link
# Communities
python main.py --task link --inductive
```
### Link pair prediction
```bash
# Communities
python main.py --task link_pair --inductive
```
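Both commands also accept the optional flags defined in `main.py`: `--epoch_num` (default 2000), `--repeat_num` (default 10, the number of independent runs averaged in the tables below), `--epoch_log` (default 100) and `--k_hop_dist` (default -1, meaning exact shortest-path distances).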
## Performance
### Link prediction (Communities-T refers to the transductive learning setting of Communities)
| Dataset | Communities-T | Communities |
| :------------------------------: | :-----------: | :-----------: |
| ROC AUC (P-GNN-E-2L in Table 1) | 0.988 ± 0.003 | 0.985 ± 0.008 |
| ROC AUC (DGL: P-GNN-E-2L) | 0.984 ± 0.010 | 0.991 ± 0.004 |
### Link pair prediction
| Dataset | Communities |
| :------------------------------: | :---------: |
| ROC AUC (P-GNN-E-2L in Table 1) | 1.0 ± 0.001 |
| ROC AUC (DGL: P-GNN-E-2L) | 1.0 ± 0.000 |
examples/pytorch/P-GNN/main.py
import os
import dgl
import torch
import numpy as np
import torch.nn as nn
from model import PGNN
from sklearn.metrics import roc_auc_score
from utils import get_dataset, preselect_anchor
import warnings
warnings.filterwarnings('ignore')
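# Score each candidate link (u, v) by the inner product of the two node
# embeddings and compute BCEWithLogitsLoss against 1/0 labels for the
# positive/negative edges of split p ('train', 'val' or 'test').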
def get_loss(p, data, out, loss_func, device, get_auc=True):
edge_mask = np.concatenate((data['positive_edges_{}'.format(p)], data['negative_edges_{}'.format(p)]), axis=-1)
nodes_first = torch.index_select(out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device))
nodes_second = torch.index_select(out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device))
pred = torch.sum(nodes_first * nodes_second, dim=-1)
label_positive = torch.ones([data['positive_edges_{}'.format(p)].shape[1], ], dtype=pred.dtype)
label_negative = torch.zeros([data['negative_edges_{}'.format(p)].shape[1], ], dtype=pred.dtype)
label = torch.cat((label_positive, label_negative)).to(device)
loss = loss_func(pred, label)
if get_auc:
auc = roc_auc_score(label.flatten().cpu().numpy(), torch.sigmoid(pred).flatten().data.cpu().numpy())
return loss, auc
else:
return loss
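# One optimization step on this epoch's precomputed anchor graph.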
def train_model(data, model, loss_func, optimizer, device, g_data):
model.train()
out = model(g_data)
loss = get_loss('train', data, out, loss_func, device, get_auc=False)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def eval_model(data, g_data, model, loss_func, device):
model.eval()
out = model(g_data)
# train loss and auc
tmp_loss, auc_train = get_loss('train', data, out, loss_func, device)
loss_train = tmp_loss.cpu().data.numpy()
# val loss and auc
_, auc_val = get_loss('val', data, out, loss_func, device)
# test loss and auc
_, auc_test = get_loss('test', data, out, loss_func, device)
return loss_train, auc_train, auc_val, auc_test
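# Run args.repeat_num independent experiments and report the mean/std of the
# test AUC at the epoch with the best validation AUC.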
def main(args):
# The mean and standard deviation of the experiment results
# are stored in the 'results' folder
if not os.path.isdir('results'):
os.mkdir('results')
if torch.cuda.is_available():
device = 'cuda:0'
else:
device = 'cpu'
print('Learning Type: {}'.format(['Transductive', 'Inductive'][args.inductive]),
'Task: {}'.format(args.task))
results = []
for repeat in range(args.repeat_num):
data = get_dataset(args)
# pre-sample anchor nodes and compute shortest distance values for all epochs
g_list, anchor_eid_list, dist_max_list, edge_weight_list = preselect_anchor(data, args)
# model
model = PGNN(input_dim=data['feature'].shape[1]).to(device)
# optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
loss_func = nn.BCEWithLogitsLoss()
best_auc_val = -1
best_auc_test = -1
for epoch in range(args.epoch_num):
if epoch == 200:
for param_group in optimizer.param_groups:
param_group['lr'] /= 10
g = dgl.graph(g_list[epoch])
g.ndata['feat'] = torch.FloatTensor(data['feature'])
g.edata['sp_dist'] = torch.FloatTensor(edge_weight_list[epoch])
g_data = {
'graph': g.to(device),
'anchor_eid': anchor_eid_list[epoch],
'dists_max': dist_max_list[epoch]
}
train_model(data, model, loss_func, optimizer, device, g_data)
loss_train, auc_train, auc_val, auc_test = eval_model(
data, g_data, model, loss_func, device)
if auc_val > best_auc_val:
best_auc_val = auc_val
best_auc_test = auc_test
if epoch % args.epoch_log == 0:
print(repeat, epoch, 'Loss {:.4f}'.format(loss_train), 'Train AUC: {:.4f}'.format(auc_train),
'Val AUC: {:.4f}'.format(auc_val), 'Test AUC: {:.4f}'.format(auc_test),
'Best Val AUC: {:.4f}'.format(best_auc_val), 'Best Test AUC: {:.4f}'.format(best_auc_test))
results.append(best_auc_test)
results = np.array(results)
results_mean = np.mean(results).round(6)
results_std = np.std(results).round(6)
print('-----------------Final-------------------')
print(results_mean, results_std)
with open('results/{}_{}_{}.txt'.format(['Transductive', 'Inductive'][args.inductive], args.task,
args.k_hop_dist), 'w') as f:
f.write('{}, {}\n'.format(results_mean, results_std))
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--task', type=str, default='link', choices=['link', 'link_pair'])
parser.add_argument('--inductive', action='store_true',
help='Inductive learning or transductive learning')
parser.add_argument('--k_hop_dist', default=-1, type=int,
help='K-hop shortest path distance, -1 means exact shortest path.')
parser.add_argument('--epoch_num', type=int, default=2000)
parser.add_argument('--repeat_num', type=int, default=10)
parser.add_argument('--epoch_log', type=int, default=100)
args = parser.parse_args()
main(args)
examples/pytorch/P-GNN/model.py
import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
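# A single P-GNN layer. For every node v and each anchor set, the layer looks
# at v's closest anchor u and computes the message
#     ReLU(W_v h_v + s(u, v) * W_u h_u),
# where s(u, v) is the reciprocal shortest-path distance stored on the edge
# ('sp_dist'). Projecting each anchor message to a scalar gives the
# position-aware output (one dimension per anchor set); averaging the messages
# over anchor sets gives the order-invariant structure output.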
class PGNN_layer(nn.Module):
def __init__(self, input_dim, output_dim):
super(PGNN_layer, self).__init__()
self.input_dim = input_dim
self.linear_hidden_u = nn.Linear(input_dim, output_dim)
self.linear_hidden_v = nn.Linear(input_dim, output_dim)
self.linear_out_position = nn.Linear(output_dim, 1)
self.act = nn.ReLU()
def forward(self, graph, feature, anchor_eid, dists_max):
with graph.local_scope():
u_feat = self.linear_hidden_u(feature)
v_feat = self.linear_hidden_v(feature)
graph.srcdata.update({'u_feat': u_feat})
graph.dstdata.update({'v_feat': v_feat})
graph.apply_edges(fn.u_mul_e('u_feat', 'sp_dist', 'u_message'))
graph.apply_edges(fn.v_add_e('v_feat', 'u_message', 'message'))
messages = torch.index_select(graph.edata['message'], 0,
torch.LongTensor(anchor_eid).to(feature.device))
messages = messages.reshape(dists_max.shape[0], dists_max.shape[1], messages.shape[-1])
messages = self.act(messages) # n*m*d
out_position = self.linear_out_position(messages).squeeze(-1) # n*m_out
out_structure = torch.mean(messages, dim=1) # n*d
return out_position, out_structure
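# Two-layer P-GNN: a linear pre-projection followed by two PGNN_layer blocks.
# The structure output of the first layer feeds the second layer; the final,
# L2-normalized position output is used for link scoring.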
class PGNN(nn.Module):
def __init__(self, input_dim, feature_dim=32, dropout=0.5):
super(PGNN, self).__init__()
self.dropout = nn.Dropout(dropout)
self.linear_pre = nn.Linear(input_dim, feature_dim)
self.conv_first = PGNN_layer(feature_dim, feature_dim)
self.conv_out = PGNN_layer(feature_dim, feature_dim)
def forward(self, data):
x = data['graph'].ndata['feat']
graph = data['graph']
x = self.linear_pre(x)
_, x = self.conv_first(graph, x, data['anchor_eid'], data['dists_max'])
x = self.dropout(x)
x_position, x = self.conv_out(graph, x, data['anchor_eid'], data['dists_max'])
x_position = F.normalize(x_position, p=2, dim=-1)
return x_position
examples/pytorch/P-GNN/utils.py
import torch
import random
import numpy as np
import networkx as nx
from tqdm.auto import tqdm
import multiprocessing as mp
from multiprocessing import get_context
def get_communities(remove_feature):
community_size = 20
# Create 20 cliques (communities) of size 20,
# then rewire a single edge in each clique to a node in an adjacent clique
graph = nx.connected_caveman_graph(20, community_size)
# Randomly rewire 1% edges
node_list = list(graph.nodes)
for (u, v) in graph.edges():
if random.random() < 0.01:
x = random.choice(node_list)
if graph.has_edge(u, x):
continue
graph.remove_edge(u, v)
graph.add_edge(u, x)
# remove self-loops
graph.remove_edges_from(nx.selfloop_edges(graph))
edge_index = np.array(list(graph.edges))
# Add (i, j) for an edge (j, i)
edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)
edge_index = torch.from_numpy(edge_index).long().permute(1, 0)
n = graph.number_of_nodes()
label = np.zeros((n, n), dtype=int)
for u in node_list:
# the node IDs are simply consecutive integers from 0
for v in range(u):
if u // community_size == v // community_size:
label[u, v] = 1
if remove_feature:
feature = torch.ones((n, 1))
else:
rand_order = np.random.permutation(n)
feature = np.identity(n)[:, rand_order]
data = {
'edge_index': edge_index,
'feature': feature,
'positive_edges': np.stack(np.nonzero(label)),
'num_nodes': feature.shape[0]
}
return data
def to_single_directed(edges):
edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)
j = 0
for i in range(edges.shape[1]):
if edges[0, i] < edges[1, i]:
edges_new[:, j] = edges[:, i]
j += 1
return edges_new
# Each node should still appear in the new graph after the split
def split_edges(p, edges, data, non_train_ratio=0.2):
e = edges.shape[1]
edges = edges[:, np.random.permutation(e)]
split1 = int((1 - non_train_ratio) * e)
split2 = int((1 - non_train_ratio / 2) * e)
data.update({
'{}_edges_train'.format(p): edges[:, :split1], # 80%
'{}_edges_val'.format(p): edges[:, split1:split2], # 10%
'{}_edges_test'.format(p): edges[:, split2:] # 10%
})
def to_bidirected(edges):
return np.concatenate((edges, edges[::-1, :]), axis=-1)
def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
positive_edge_set = []
positive_edges = to_bidirected(positive_edges)
for i in range(positive_edges.shape[1]):
positive_edge_set.append(tuple(positive_edges[:, i]))
positive_edge_set = set(positive_edge_set)
negative_edges = np.zeros((2, num_negative_edges), dtype=positive_edges.dtype)
for i in range(num_negative_edges):
while True:
mask_temp = tuple(np.random.choice(num_nodes, size=(2,), replace=False))
if mask_temp not in positive_edge_set:
negative_edges[:, i] = mask_temp
break
return negative_edges
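# Build the train/val/test edge splits (80/10/10) and rejection-sample an
# equal number of negative (non-existent) edges, split the same way.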
def get_pos_neg_edges(data, infer_link_positive=True):
if infer_link_positive:
data['positive_edges'] = to_single_directed(data['edge_index'].numpy())
split_edges('positive', data['positive_edges'], data)
# resample edge mask link negative
negative_edges = get_negative_edges(data['positive_edges'], data['num_nodes'],
num_negative_edges=data['positive_edges'].shape[1])
split_edges('negative', negative_edges, data)
return data
def shortest_path(graph, node_range, cutoff):
dists_dict = {}
for node in tqdm(node_range, leave=False):
dists_dict[node] = nx.single_source_shortest_path_length(graph, node, cutoff)
return dists_dict
def merge_dicts(dicts):
result = {}
for dictionary in dicts:
result.update(dictionary)
return result
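# Compute single-source shortest paths in parallel: shuffle the nodes, shard
# them across worker processes and merge the per-shard distance dictionaries.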
def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
nodes = list(graph.nodes)
random.shuffle(nodes)
pool = mp.Pool(processes=num_workers)
interval_size = len(nodes) / num_workers
results = [pool.apply_async(shortest_path, args=(
graph, nodes[int(interval_size * i): int(interval_size * (i + 1))], cutoff))
for i in range(num_workers)]
output = [p.get() for p in results]
dists_dict = merge_dicts(output)
pool.close()
pool.join()
return dists_dict
def precompute_dist_data(edge_index, num_nodes, approximate=0):
"""
Here dist is 1/real_dist, higher actually means closer, 0 means disconnected
:return:
"""
graph = nx.Graph()
edge_list = edge_index.transpose(1, 0).tolist()
graph.add_edges_from(edge_list)
n = num_nodes
dists_array = np.zeros((n, n))
dists_dict = all_pairs_shortest_path(graph, cutoff=approximate if approximate > 0 else None)
node_list = graph.nodes()
for node_i in node_list:
shortest_dist = dists_dict[node_i]
for node_j in node_list:
dist = shortest_dist.get(node_j, -1)
if dist != -1:
dists_array[node_i, node_j] = 1 / (dist + 1)
return dists_array
def get_dataset(args):
# Generate graph data
data_info = get_communities(args.inductive)
# Get positive and negative edges
data = get_pos_neg_edges(data_info, infer_link_positive=True if args.task == 'link' else False)
# Pre-compute shortest path length
if args.task == 'link':
dists_removed = precompute_dist_data(data['positive_edges_train'], data['num_nodes'],
approximate=args.k_hop_dist)
data['dists'] = torch.from_numpy(dists_removed).float()
data['edge_index'] = torch.from_numpy(to_bidirected(data['positive_edges_train'])).long()
else:
dists = precompute_dist_data(data['edge_index'].numpy(), data['num_nodes'],
approximate=args.k_hop_dist)
data['dists'] = torch.from_numpy(dists).float()
return data
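# Anchor-set sampling: draw m = floor(log2(n)) sets for each of the m sizes
# n/2, n/4, ..., n/2^m, i.e. O(log^2 n) anchor sets in total, following the
# sampling scheme of the P-GNN paper.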
def get_anchors(n):
"""Get a list of NumPy arrays, each of them is an anchor node set"""
m = int(np.log2(n))
anchor_set_id = []
for i in range(m):
anchor_size = int(n / np.exp2(i + 1))
for _ in range(m):
anchor_set_id.append(np.random.choice(n, size=anchor_size, replace=False))
return anchor_set_id
def get_dist_max(anchor_set_id, dist):
# N x K, N is number of nodes, K is the number of anchor sets
dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))
dist_argmax = torch.zeros((dist.shape[0], len(anchor_set_id))).long()
for i in range(len(anchor_set_id)):
temp_id = torch.as_tensor(anchor_set_id[i], dtype=torch.long)
# Get reciprocal of shortest distance to each node in the i-th anchor set
dist_temp = torch.index_select(dist, 1, temp_id)
# For each node in the graph, find its closest anchor node in the set
# and the reciprocal of shortest distance
dist_max_temp, dist_argmax_temp = torch.max(dist_temp, dim=-1)
dist_max[:, i] = dist_max_temp
dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)
return dist_max, dist_argmax
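# Build the per-epoch message graph: for every node i, add one edge from each
# of its (unique) closest anchors to i, weighted by the reciprocal shortest
# distance. anchor_eid maps each (anchor set, node) pair to the edge ID that
# carries its message in PGNN_layer.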
def get_a_graph(dists_max, dists_argmax):
src = []
dst = []
real_src = []
real_dst = []
edge_weight = []
dists_max = dists_max.numpy()
for i in range(dists_max.shape[0]):
# Get unique closest anchor nodes for node i across all anchor sets
tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(dists_argmax[i, :], return_index=True)
src.extend([i] * tmp_dists_argmax.shape[0])
real_src.extend([i] * dists_argmax[i, :].shape[0])
real_dst.extend(list(dists_argmax[i, :].numpy()))
dst.extend(list(tmp_dists_argmax))
edge_weight.extend(dists_max[i, tmp_dists_argmax_idx].tolist())
eid_dict = {(u, v): i for i, (u, v) in enumerate(list(zip(dst, src)))}
anchor_eid = [eid_dict.get((u, v)) for u, v in zip(real_dst, real_src)]
g = (dst, src)
return g, anchor_eid, edge_weight
def get_graphs(data, anchor_sets):
graphs = []
anchor_eids = []
dists_max_list = []
edge_weights = []
for anchor_set in tqdm(anchor_sets, leave=False):
dists_max, dists_argmax = get_dist_max(anchor_set, data['dists'])
g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
graphs.append(g)
anchor_eids.append(anchor_eid)
dists_max_list.append(dists_max)
edge_weights.append(edge_weight)
return graphs, anchor_eids, dists_max_list, edge_weights
def merge_result(outputs):
graphs = []
anchor_eids = []
dists_max_list = []
edge_weights = []
for g, anchor_eid, dists_max, edge_weight in outputs:
graphs.extend(g)
anchor_eids.extend(anchor_eid)
dists_max_list.extend(dists_max)
edge_weights.extend(edge_weight)
return graphs, anchor_eids, dists_max_list, edge_weights
def preselect_anchor(data, args, num_workers=4):
pool = get_context("spawn").Pool(processes=num_workers)
# Pre-compute anchor sets, a collection of anchor sets per epoch
anchor_set_ids = [get_anchors(data['num_nodes']) for _ in range(args.epoch_num)]
interval_size = len(anchor_set_ids) / num_workers
results = [pool.apply_async(get_graphs, args=(
data, anchor_set_ids[int(interval_size * i):int(interval_size * (i + 1))],))
for i in range(num_workers)]
output = [p.get() for p in results]
graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)
pool.close()
pool.join()
return graphs, anchor_eids, dists_max_list, edge_weights