Unverified Commit 92bff4d2 authored by Ereboas, committed by GitHub

[Example] SEAL+NGNN for ogbl (#4550)



* Use black for formatting

* limit line width to 80 characters.

* Use a backslash instead of directly concatenating

* file structure adjustment.

* file structure adjustment(2)
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent ec4271bf
# NGNN + SEAL
## Introduction
This is an implementation of [NGNN](https://arxiv.org/abs/2111.11638) + [SEAL](https://arxiv.org/pdf/2010.16103.pdf), submitted to the OGB link prediction leaderboards. Some code is adapted from [https://github.com/facebookresearch/SEAL_OGB](https://github.com/facebookresearch/SEAL_OGB).
## Installation Requirements
```
ogb>=1.3.4
torch>=1.12.0
dgl>=0.8
scipy, numpy, tqdm...
```
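Assuming a pip-based environment, the dependencies can be installed with something like the following (the exact DGL/PyTorch wheels depend on your CUDA version):
```
pip install "ogb>=1.3.4" "torch>=1.12.0" "dgl>=0.8" scipy numpy tqdm
```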
## Experiments
We do not fix random seeds, and report results over 10 runs for each model. All models are trained on a single T4 GPU (16GB memory) with 96 vCPUs.
### ogbl-ppa
#### Performance
| | Test Hits@100 | Validation Hits@100 | #Parameters |
|:------------:|:-------------------:|:-----------------:|:------------:|
| SEAL | 48.80% ± 3.16% | 51.25% ± 2.52% | 709,122 |
| SEAL + NGNN | 59.71% ± 2.45% | 59.95% ± 2.05% | 735,426 |
#### Reproduction of performance
```bash
python main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 \
    --epochs 50 --lr 0.00015 --batch_size 128 --num_workers 48 \
    --train_percent 5 --val_percent 8 --eval_hits_K 10 --use_feature \
    --dynamic_train --dynamic_val --dynamic_test --runs 10
```
As training is very costly, we select the best model by evaluating on a subset of the validation edges, using a lower K for Hits@K. We then evaluate the selected model on the full validation and test sets to obtain the reported metrics.
For all datasets, if you specify `--dynamic_train`, the enclosing subgraphs of the training links are extracted on the fly instead of being preprocessed and saved to disk; likewise for `--dynamic_val` and `--dynamic_test`. You can increase `--num_workers` to accelerate dynamic subgraph extraction.
You can also specify the `val_percent` and `eval_hits_K` arguments in the above command to adjust the proportion of the validation dataset to use and the K to use for Hits@K.
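For example, to select models on 10% of the validation edges with Hits@20 instead (the values here are illustrative only):
```bash
python main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 \
    --epochs 50 --lr 0.00015 --batch_size 128 --num_workers 48 \
    --train_percent 5 --val_percent 10 --eval_hits_K 20 --use_feature \
    --dynamic_train --dynamic_val --dynamic_test --runs 10
```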
## Reference
```
@article{DBLP:journals/corr/abs-2111-11638,
  author     = {Xiang Song and
                Runjie Ma and
                Jiahang Li and
                Muhan Zhang and
                David Paul Wipf},
  title      = {Network In Graph Neural Network},
  journal    = {CoRR},
  volume     = {abs/2111.11638},
  year       = {2021},
  url        = {https://arxiv.org/abs/2111.11638},
  eprinttype = {arXiv},
  eprint     = {2111.11638},
  timestamp  = {Fri, 26 Nov 2021 13:48:43 +0100},
  biburl     = {https://dblp.org/rec/journals/corr/abs-2111-11638.bib},
  bibsource  = {dblp computer science bibliography, https://dblp.org}
}

@article{zhang2021labeling,
  title   = {Labeling Trick: A Theory of Using Graph Neural Networks for Multi-Node Representation Learning},
  author  = {Zhang, Muhan and Li, Pan and Xia, Yinglong and Wang, Kai and Jin, Long},
  journal = {Advances in Neural Information Processing Systems},
  volume  = {34},
  year    = {2021}
}

@inproceedings{zhang2018link,
  title     = {Link prediction based on graph neural networks},
  author    = {Zhang, Muhan and Chen, Yixin},
  booktitle = {Advances in Neural Information Processing Systems},
  pages     = {5165--5175},
  year      = {2018}
}
```
import math
import torch
import torch.nn.functional as F
from dgl.nn import GraphConv, SortPooling
from torch.nn import Conv1d, Embedding, Linear, MaxPool1d, ModuleList


# GCN layer with a nested feed-forward network (the NGNN construction):
# GraphConv followed by a small MLP applied to each node representation.
class NGNN_GCNConv(torch.nn.Module):
def __init__(self, input_channels, hidden_channels, output_channels):
super(NGNN_GCNConv, self).__init__()
self.conv = GraphConv(input_channels, hidden_channels)
self.fc = Linear(hidden_channels, hidden_channels)
self.fc2 = Linear(hidden_channels, output_channels)

    def reset_parameters(self):
self.conv.reset_parameters()
gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)
for bias in [self.fc.bias, self.fc2.bias]:
stdv = 1.0 / math.sqrt(bias.size(0))
bias.data.uniform_(-stdv, stdv)

    def forward(self, g, x, edge_weight=None):
        # Pass edge_weight by keyword: GraphConv's third positional
        # argument is an external weight matrix, not per-edge weights.
        x = self.conv(g, x, edge_weight=edge_weight)
        # The intermediate linear layer is defined but disabled here:
        # x = F.relu(x)
        # x = self.fc(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
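

# Hedged usage sketch (not part of the original file): exercise
# NGNN_GCNConv on a toy DGL graph; all names and sizes are illustrative.
def _demo_ngnn_gcnconv():
    import dgl

    g = dgl.add_self_loop(dgl.rand_graph(6, 12))  # self-loops: no 0-degree
    conv = NGNN_GCNConv(8, 16, 16)
    out = conv(g, torch.randn(g.num_nodes(), 8))
    assert out.shape == (6, 16)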


# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
def __init__(
self,
hidden_channels,
num_layers,
max_z,
k,
feature_dim=0,
GNN=GraphConv,
NGNN=NGNN_GCNConv,
dropout=0.0,
ngnn_type="all",
):
super(DGCNN, self).__init__()
self.feature_dim = feature_dim
self.dropout = dropout
self.k = k
self.sort_pool = SortPooling(k=self.k)
self.max_z = max_z
self.z_embedding = Embedding(self.max_z, hidden_channels)
self.convs = ModuleList()
initial_channels = hidden_channels + self.feature_dim
if ngnn_type in ["input", "all"]:
self.convs.append(
NGNN(initial_channels, hidden_channels, hidden_channels)
)
else:
self.convs.append(GNN(initial_channels, hidden_channels))
if ngnn_type in ["hidden", "all"]:
for _ in range(0, num_layers - 1):
self.convs.append(
NGNN(hidden_channels, hidden_channels, hidden_channels)
)
else:
for _ in range(0, num_layers - 1):
self.convs.append(GNN(hidden_channels, hidden_channels))
if ngnn_type in ["output", "all"]:
self.convs.append(NGNN(hidden_channels, hidden_channels, 1))
else:
self.convs.append(GNN(hidden_channels, 1))
        # 1-D convolution readout from the DGCNN paper; the first kernel
        # spans one node's concatenated layer outputs (total_latent_dim).
        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
self.maxpool1d = MaxPool1d(2, 2)
self.conv2 = Conv1d(
conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
)
dense_dim = int((self.k - 2) / 2 + 1)
dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
self.lin1 = Linear(dense_dim, 128)
self.lin2 = Linear(128, 1)

    def forward(self, g, z, x=None, edge_weight=None):
z_emb = self.z_embedding(z)
if z_emb.ndim == 3: # in case z has multiple integer labels
z_emb = z_emb.sum(dim=1)
if x is not None:
x = torch.cat([z_emb, x.to(torch.float)], 1)
else:
x = z_emb
xs = [x]
for conv in self.convs:
xs += [
F.dropout(
torch.tanh(conv(g, xs[-1], edge_weight=edge_weight)),
p=self.dropout,
training=self.training,
)
]
x = torch.cat(xs[1:], dim=-1)
# global pooling
x = self.sort_pool(g, x)
x = x.unsqueeze(1) # [num_graphs, 1, k * hidden]
x = F.relu(self.conv1(x))
x = self.maxpool1d(x)
x = F.relu(self.conv2(x))
x = x.view(x.size(0), -1) # [num_graphs, dense_dim]
# MLP.
x = F.relu(self.lin1(x))
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin2(x)
return x
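

# Hedged usage sketch (added for illustration): one DGCNN forward pass
# over a batch of toy "enclosing subgraphs"; max_z/k are arbitrary here.
def _demo_dgcnn():
    import dgl

    model = DGCNN(hidden_channels=32, num_layers=3, max_z=1000, k=30)
    bg = dgl.batch(
        [dgl.add_self_loop(dgl.rand_graph(40, 120)) for _ in range(4)]
    )
    z = torch.randint(0, 1000, (bg.num_nodes(),))  # stand-in DRNL labels
    logits = model(bg, z)  # one link logit per subgraph
    assert logits.shape == (4, 1)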


import random
import sys
import numpy as np
import torch
from dgl.sampling import global_uniform_negative_sampling
from scipy.sparse.csgraph import shortest_path


def k_hop_subgraph(src, dst, num_hops, g, sample_ratio=1.0, directed=False):
# Extract the k-hop enclosing subgraph around link (src, dst) from g
nodes = [src, dst]
visited = set([src, dst])
fringe = set([src, dst])
for _ in range(num_hops):
if not directed:
_, fringe = g.out_edges(list(fringe))
fringe = fringe.tolist()
else:
_, out_neighbors = g.out_edges(list(fringe))
in_neighbors, _ = g.in_edges(list(fringe))
fringe = in_neighbors.tolist() + out_neighbors.tolist()
fringe = set(fringe) - visited
visited = visited.union(fringe)
        if sample_ratio < 1.0:
            # random.sample requires a sequence, not a set (Python >= 3.11)
            fringe = random.sample(
                list(fringe), int(sample_ratio * len(fringe))
            )
if len(fringe) == 0:
break
nodes = nodes + list(fringe)
subg = g.subgraph(nodes, store_ids=True)
return subg
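

# Hedged usage sketch: extract the 1-hop enclosing subgraph around an
# assumed candidate link (0, 3) of a random toy graph.
def _demo_k_hop_subgraph():
    import dgl

    g = dgl.rand_graph(100, 500)
    subg = k_hop_subgraph(0, 3, num_hops=1, g=g)
    # store_ids=True keeps the original node IDs; src and dst come first
    print(subg.num_nodes(), subg.ndata[dgl.NID][:2])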


def drnl_node_labeling(adj, src, dst):
# Double Radius Node Labeling (DRNL).
    src, dst = (dst, src) if src > dst else (src, dst)
    # Per DRNL, distances to src are computed with dst removed from the
    # graph, and vice versa.
    idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
    adj_wo_src = adj[idx, :][:, idx]
    idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
    adj_wo_dst = adj[idx, :][:, idx]
dist2src = shortest_path(
adj_wo_dst, directed=False, unweighted=True, indices=src
)
dist2src = np.insert(dist2src, dst, 0, axis=0)
dist2src = torch.from_numpy(dist2src)
dist2dst = shortest_path(
adj_wo_src, directed=False, unweighted=True, indices=dst - 1
)
dist2dst = np.insert(dist2dst, src, 0, axis=0)
dist2dst = torch.from_numpy(dist2dst)
dist = dist2src + dist2dst
dist_over_2, dist_mod_2 = (
torch.div(dist, 2, rounding_mode="floor"),
dist % 2,
)
z = 1 + torch.min(dist2src, dist2dst)
z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
z[src] = 1.0
z[dst] = 1.0
# shortest path may include inf values
z[torch.isnan(z)] = 0.0
return z.to(torch.long)
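

# Hedged sketch: DRNL labels on a hand-built 4-cycle (illustration
# only); `adj` must be a SciPy sparse matrix for the slicing above.
def _demo_drnl():
    from scipy.sparse import csr_matrix

    rows = [0, 1, 1, 2, 2, 3, 3, 0]
    cols = [1, 0, 2, 1, 3, 2, 0, 3]
    adj = csr_matrix((np.ones(8), (rows, cols)), shape=(4, 4))
    z = drnl_node_labeling(adj, 0, 1)  # label w.r.t. target link (0, 1)
    print(z)  # tensor([1, 1, 3, 3]): src/dst get 1, others by distance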


def get_pos_neg_edges(split, split_edge, g, percent=100):
pos_edge = split_edge[split]["edge"]
if split == "train":
neg_edge = torch.stack(
global_uniform_negative_sampling(
g, num_samples=pos_edge.size(0), exclude_self_loops=True
),
dim=1,
)
else:
neg_edge = split_edge[split]["edge_neg"]
# sampling according to the percent param
np.random.seed(123)
# pos sampling
num_pos = pos_edge.size(0)
perm = np.random.permutation(num_pos)
perm = perm[: int(percent / 100 * num_pos)]
pos_edge = pos_edge[perm]
# neg sampling
if neg_edge.dim() > 2: # [Np, Nn, 2]
neg_edge = neg_edge[perm].view(-1, 2)
else:
np.random.seed(123)
num_neg = neg_edge.size(0)
perm = np.random.permutation(num_neg)
perm = perm[: int(percent / 100 * num_neg)]
neg_edge = neg_edge[perm]
    return pos_edge, neg_edge  # pos_edge: [Np, 2], neg_edge: [Nn, 2]
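

# Hedged sketch assuming the standard OGB loader; percent=8 mirrors the
# --val_percent value in the README command.
def _demo_get_pos_neg_edges():
    from ogb.linkproppred import DglLinkPropPredDataset

    dataset = DglLinkPropPredDataset(name="ogbl-ppa")
    split_edge = dataset.get_edge_split()
    pos, neg = get_pos_neg_edges("valid", split_edge, dataset[0], percent=8)
    print(pos.shape, neg.shape)  # [Np, 2] and [Nn, 2]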


class Logger(object):
def __init__(self, runs, info=None):
self.info = info
self.results = {
"valid": [[] for _ in range(runs)],
"test": [[] for _ in range(runs)],
}

    def add_result(self, run, result, split="valid"):
assert run >= 0 and run < len(self.results["valid"])
assert split in ["valid", "test"]
self.results[split][run].append(result)

    def print_statistics(self, run=None, f=sys.stdout):
if run is not None:
result = torch.tensor(self.results["valid"][run])
print(f"Run {run + 1:02d}:", file=f)
print(f"Highest Valid: {result.max():.4f}", file=f)
print(f"Highest Eval Point: {result.argmax().item()+1}", file=f)
if not self.info.no_test:
print(
f' Final Test Point[1]: {self.results["test"][run][0][0]}',
f' Final Valid: {self.results["test"][run][0][1]}',
f' Final Test: {self.results["test"][run][0][2]}',
sep='\n',
file=f,
)
else:
best_result = torch.tensor(
[test_res[0] for test_res in self.results["test"]]
)
print(f"All runs:", file=f)
r = best_result[:, 1]
print(f"Highest Valid: {r.mean():.4f} ± {r.std():.4f}", file=f)
if not self.info.no_test:
r = best_result[:, 2]
print(f" Final Test: {r.mean():.4f} ± {r.std():.4f}", file=f)