Unverified Commit 92bff4d2 authored by Ereboas's avatar Ereboas Committed by GitHub
Browse files

[Example] SEAL+NGNN for ogbl (#4550)



* Use black for formatting

* limit line width to 80 characters.

* Use a backslash instead of directly concatenating

* file structure adjustment.

* file structure adjustment(2)
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent ec4271bf
# NGNN + SEAL
## Introduction
This submission implements [NGNN](https://arxiv.org/abs/2111.11638) + [SEAL](https://arxiv.org/pdf/2010.16103.pdf) for the OGB link prediction leaderboards. Some code is migrated from [https://github.com/facebookresearch/SEAL_OGB](https://github.com/facebookresearch/SEAL_OGB).
## Installation Requirements
```
ogb>=1.3.4
torch>=1.12.0
dgl>=0.8
scipy, numpy, tqdm...
```
## Experiments
We do not fix random seeds at all, and report results over 10 runs for all models. All models are trained on a single T4 GPU with 16GB memory and 96 vCPUs.
### ogbl-ppa
#### performance
| | Test Hits@100 | Validation Hits@100 | #Parameters |
|:------------:|:-------------------:|:-----------------:|:------------:|
| SEAL | 48.80% ± 3.16% | 51.25% ± 2.52% | 709,122 |
| SEAL + NGNN | 59.71% ± 2.45% | 59.95% ± 2.05% | 735,426 |
#### Reproduction of performance
```{.bash}
python main.py --dataset ogbl-ppa --ngnn_type input --hidden_channels 48 --epochs 50 --lr 0.00015 --batch_size 128 --num_workers 48 --train_percent 5 --val_percent 8 --eval_hits_K 10 --use_feature --dynamic_train --dynamic_val --dynamic_test --runs 10
```
As training is very costly, we select the best model by evaluation on a subset of the validation edges and using a lower K for Hits@K. Then we do experiments on the full validation and test sets with the best model selected, and get the required metrics.
For all datasets, if you specify `--dynamic_train`, the enclosing subgraphs of the training links will be extracted on the fly instead of preprocessing and saving to disk. Similarly for `--dynamic_val` and `--dynamic_test`. You can increase `--num_workers` to accelerate the dynamic subgraph extraction process.
You can also specify the `val_percent` and `eval_hits_K` arguments in the above command to adjust the proportion of the validation dataset to use and the K to use for Hits@K.
## Reference
@article{DBLP:journals/corr/abs-2111-11638,
author = {Xiang Song and
Runjie Ma and
Jiahang Li and
Muhan Zhang and
David Paul Wipf},
title = {Network In Graph Neural Network},
journal = {CoRR},
volume = {abs/2111.11638},
year = {2021},
url = {https://arxiv.org/abs/2111.11638},
eprinttype = {arXiv},
eprint = {2111.11638},
timestamp = {Fri, 26 Nov 2021 13:48:43 +0100},
biburl = {https://dblp.org/rec/journals/corr/abs-2111-11638.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@article{zhang2021labeling,
title={Labeling Trick: A Theory of Using Graph Neural Networks for Multi-Node Representation Learning},
author={Zhang, Muhan and Li, Pan and Xia, Yinglong and Wang, Kai and Jin, Long},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
@inproceedings{zhang2018link,
title={Link prediction based on graph neural networks},
author={Zhang, Muhan and Chen, Yixin},
booktitle={Advances in Neural Information Processing Systems},
pages={5165--5175},
year={2018}
}
\ No newline at end of file
import argparse
import datetime
import math
import os
import sys
import time

import dgl
import torch
from dgl.data.utils import load_graphs, save_graphs
from dgl.dataloading import GraphDataLoader
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import Dataset
from tqdm import tqdm

from models import *
from utils import *
class SEALOGBLDataset(Dataset):
    """SEAL-style dataset of enclosing subgraphs for OGB link prediction.

    Each item is a tuple ``(subgraph, z, x, edge_weights, y)`` where ``z``
    is the DRNL node-label tensor, ``x`` the raw node features (or None),
    ``edge_weights`` the per-edge weights (or None) and ``y`` the 0/1 link
    label.  Subgraphs are either extracted on the fly (``dynamic=True``)
    or preprocessed once and cached on disk (``dynamic=False``).
    """

    def __init__(
        self,
        root,
        graph,
        split_edge,
        percent=100,
        split="train",
        ratio_per_hop=1.0,
        directed=False,
        dynamic=True,
    ) -> None:
        super().__init__()
        self.root = root
        self.graph = graph
        self.split = split
        self.split_edge = split_edge
        self.percent = percent
        self.ratio_per_hop = ratio_per_hop
        self.directed = directed
        self.dynamic = dynamic

        if not self.dynamic:
            # Static mode: all subgraphs are precomputed and read from disk.
            self.g_list, tensor_dict = self.load_cached()
            self.labels = tensor_dict["y"]
            return

        if "weights" in self.graph.edata:
            self.edge_weights = self.graph.edata["weights"]
        else:
            self.edge_weights = None
        if "feat" in self.graph.ndata:
            self.node_features = self.graph.ndata["feat"]
        else:
            self.node_features = None

        pos_edge, neg_edge = get_pos_neg_edges(
            split, self.split_edge, self.graph, self.percent
        )
        self.links = torch.cat([pos_edge, neg_edge], 0).tolist()  # [Np + Nn, 2]
        self.labels = [1] * len(pos_edge) + [0] * len(neg_edge)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if not self.dynamic:
            g, y = self.g_list[idx], self.labels[idx]
            x = None if "x" not in g.ndata else g.ndata["x"]
            # BUG FIX: was `g.eata["w"]` (AttributeError); cached edge
            # weights are stored in g.edata by process().
            w = None if "w" not in g.edata else g.edata["w"]
            return g, g.ndata["z"], x, w, y

        src, dst = self.links[idx]
        y = self.labels[idx]
        subg = k_hop_subgraph(
            src, dst, 1, self.graph, self.ratio_per_hop, self.directed
        )

        # Remove the link between src and dst (they are always nodes 0 and 1
        # of the extracted subgraph) so the model cannot trivially see it.
        direct_links = [[], []]
        for s, t in [(0, 1), (1, 0)]:
            if subg.has_edges_between(s, t):
                direct_links[0].append(s)
                direct_links[1].append(t)
        if len(direct_links[0]):
            subg.remove_edges(subg.edge_ids(*direct_links))

        NIDs, EIDs = subg.ndata[dgl.NID], subg.edata[dgl.EID]
        z = drnl_node_labeling(subg.adj(scipy_fmt="csr"), 0, 1)
        edge_weights = (
            self.edge_weights[EIDs] if self.edge_weights is not None else None
        )
        x = self.node_features[NIDs] if self.node_features is not None else None

        subg_aug = subg.add_self_loop()
        if edge_weights is not None:
            # Give the self-loops added above a unit weight.
            edge_weights = torch.cat(
                [
                    edge_weights,
                    torch.ones(subg_aug.num_edges() - subg.num_edges()),
                ]
            )
        return subg_aug, z, x, edge_weights, y

    @property
    def cached_name(self):
        # On-disk cache file name, unique per split and sampling percent.
        return f"SEAL_{self.split}_{self.percent}%.pt"

    def process(self):
        """Extract all subgraphs once and pack them for on-disk caching."""
        g_list, labels = [], []
        # Temporarily switch to dynamic mode so __getitem__ extracts
        # subgraphs instead of reading the (not yet existing) cache.
        self.dynamic = True
        for i in tqdm(range(len(self))):
            g, z, x, weights, y = self[i]
            g.ndata["z"] = z
            if x is not None:
                g.ndata["x"] = x
            if weights is not None:
                g.edata["w"] = weights
            g_list.append(g)
            labels.append(y)
        self.dynamic = False
        return g_list, {"y": torch.tensor(labels)}

    def load_cached(self):
        """Load cached subgraphs, building and saving the cache if missing."""
        path = os.path.join(self.root, self.cached_name)
        if os.path.exists(path):
            return load_graphs(path)

        if not os.path.exists(self.root):
            os.makedirs(self.root)
        pos_edge, neg_edge = get_pos_neg_edges(
            self.split, self.split_edge, self.graph, self.percent
        )
        self.links = torch.cat([pos_edge, neg_edge], 0).tolist()  # [Np + Nn, 2]
        self.labels = [1] * len(pos_edge) + [0] * len(neg_edge)
        g_list, labels = self.process()
        save_graphs(path, g_list, labels)
        return g_list, labels
def ogbl_collate_fn(batch):
    """Collate (graph, z, x, w, y) samples into one batched graph tuple."""
    graphs, z_list, x_list, w_list, y_list = zip(*batch)
    batched_g = dgl.batch(graphs)
    z = torch.cat(z_list, dim=0)
    # Features / weights are either present for every sample or for none.
    x = torch.cat(x_list, dim=0) if x_list[0] is not None else None
    edge_weights = torch.cat(w_list, dim=0) if w_list[0] is not None else None
    y = torch.tensor(y_list)
    return batched_g, z, x, edge_weights, y
def train():
    """Run one training epoch over `train_loader`; return mean loss per sample."""
    model.train()
    criterion = BCEWithLogitsLoss()
    total_loss = 0
    for batch in tqdm(train_loader, ncols=70):
        g, z, x, edge_weights, y = (
            None if item is None else item.to(device) for item in batch
        )
        optimizer.zero_grad()
        logits = model(g, z, x, edge_weight=edge_weights)
        loss = criterion(logits.view(-1), y.to(torch.float))
        loss.backward()
        optimizer.step()
        # Weight the batch loss by its size so the epoch average is per-sample.
        total_loss += loss.item() * g.batch_size
    return total_loss / len(train_dataset)
@torch.no_grad()
def test(dataloader, hits_K=("hits@100",)):
    """Evaluate the global `model` on `dataloader`.

    `hits_K` lists the Hits@K metrics to report when the dataset uses a
    hits@xx eval metric; it is ignored for "mrr" and "rocauc" datasets.
    Returns a dict mapping metric name to value.

    Fixes vs. the original: the default is an immutable tuple instead of a
    mutable list, and an unknown `dataset.eval_metric` now raises instead
    of leaving `results` unbound (NameError at the return).
    """
    model.eval()
    if isinstance(hits_K, (int, str)):
        hits_K = [hits_K]

    y_pred, y_true = [], []
    for batch in tqdm(dataloader, ncols=70):
        g, z, x, edge_weights, y = [
            item.to(device) if item is not None else None for item in batch
        ]
        logits = model(g, z, x, edge_weight=edge_weights)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(y.view(-1).cpu().to(torch.float))
    y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
    pos_y_pred = y_pred[y_true == 1]
    neg_y_pred = y_pred[y_true == 0]

    if dataset.eval_metric.startswith("hits@"):
        results = evaluate_hits(pos_y_pred, neg_y_pred, hits_K)
    elif dataset.eval_metric == "mrr":
        results = evaluate_mrr(pos_y_pred, neg_y_pred)
    elif dataset.eval_metric == "rocauc":
        results = evaluate_rocauc(pos_y_pred, neg_y_pred)
    else:
        raise ValueError(f"Unsupported eval metric: {dataset.eval_metric}")
    return results
def evaluate_hits(y_pred_pos, y_pred_neg, hits_K):
    """Compute Hits@K with the OGB evaluator for every K in `hits_K`.

    Entries of `hits_K` may be ints or strings like "hits@100".
    """
    results = {}
    for K in hits_K:
        if isinstance(K, str):
            K = int(K.split("@")[1])
        evaluator.K = K
        key = f"hits@{K}"
        results[key] = evaluator.eval(
            {
                "y_pred_pos": y_pred_pos,
                "y_pred_neg": y_pred_neg,
            }
        )[key]
    return results
def evaluate_mrr(y_pred_pos, y_pred_neg):
    """Compute mean reciprocal rank with the OGB evaluator.

    Negatives are reshaped to one row of candidates per positive edge.
    """
    neg_per_pos = y_pred_neg.view(y_pred_pos.shape[0], -1)
    mrr_list = evaluator.eval(
        {
            "y_pred_pos": y_pred_pos,
            "y_pred_neg": neg_per_pos,
        }
    )["mrr_list"]
    return {"mrr": mrr_list.mean().item()}
def evaluate_rocauc(y_pred_pos, y_pred_neg):
    """Compute ROC-AUC with the OGB evaluator."""
    score = evaluator.eval(
        {
            "y_pred_pos": y_pred_pos,
            "y_pred_neg": y_pred_neg,
        }
    )["rocauc"]
    return {"rocauc": score}
def print_log(*x, sep="\n", end="\n", mode="a"):
    """Print to stdout and mirror the same text into the global `log_file`."""
    print(*x, sep=sep, end=end)
    with open(log_file, mode=mode) as log:
        print(*x, sep=sep, end=end, file=log)
if __name__ == "__main__":
    # Data settings
    parser = argparse.ArgumentParser(description="OGBL (SEAL)")
    parser.add_argument("--dataset", type=str, default="ogbl-vessel")
    # GNN settings
    parser.add_argument(
        "--max_z",
        type=int,
        default=1000,
        help="max number of labels as embeddings to look up",
    )
    parser.add_argument("--sortpool_k", type=float, default=0.6)
    parser.add_argument("--num_layers", type=int, default=3)
    parser.add_argument("--hidden_channels", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument(
        "--ngnn_type",
        type=str,
        default="none",
        choices=["none", "input", "hidden", "output", "all"],
        # BUG FIX: the help text omitted the valid choice 'output'.
        help="You can set this value from 'none', 'input', 'hidden', "
        "'output' or 'all' to apply NGNN to different GNN layers.",
    )
    # Subgraph extraction settings
    parser.add_argument("--ratio_per_hop", type=float, default=1.0)
    parser.add_argument(
        "--use_feature",
        action="store_true",
        help="whether to use raw node features as GNN input",
    )
    parser.add_argument(
        "--use_edge_weight",
        action="store_true",
        help="whether to consider edge weight in GNN",
    )
    # Training settings
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training.",
    )
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--train_percent", type=float, default=1)
    parser.add_argument("--val_percent", type=float, default=1)
    parser.add_argument("--final_val_percent", type=float, default=1)
    parser.add_argument("--test_percent", type=float, default=1)
    parser.add_argument("--no_test", action="store_true")
    parser.add_argument(
        "--dynamic_train",
        action="store_true",
        help="dynamically extract enclosing subgraphs on the fly",
    )
    parser.add_argument("--dynamic_val", action="store_true")
    parser.add_argument("--dynamic_test", action="store_true")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=24,
        help="number of workers for dynamic dataloaders; "
        "using a larger value for dynamic dataloading is recommended",
    )
    # Testing settings
    parser.add_argument(
        "--use_valedges_as_input",
        action="store_true",
        help="available for ogbl-collab",
    )
    parser.add_argument("--eval_steps", type=int, default=1)
    parser.add_argument(
        "--eval_hits_K",
        type=int,
        nargs="*",
        default=[10],
        help="hits@K for each eval step; "
        "only available for datasets with hits@xx as the eval metric",
    )
    parser.add_argument(
        "--test_topk",
        type=int,
        default=1,
        help="select best k models for full validation/test each run.",
    )
    args = parser.parse_args()

    # Suffix identifying the subgraph-extraction settings in the cache path.
    data_appendix = "_rph{}".format("".join(str(args.ratio_per_hop).split(".")))
    if args.use_valedges_as_input:
        data_appendix += "_uvai"

    args.res_dir = os.path.join(
        f'results{"_NoTest" if args.no_test else ""}',
        f'{args.dataset.split("-")[1]}-{args.ngnn_type}+{time.strftime("%m%d%H%M%S")}',
    )
    print(f"Results will be saved in {args.res_dir}")
    if not os.path.exists(args.res_dir):
        os.makedirs(args.res_dir)
    log_file = os.path.join(args.res_dir, "log.txt")
    # Save command line input.
    cmd_input = "python " + " ".join(sys.argv) + "\n"
    with open(os.path.join(args.res_dir, "cmd_input.txt"), "a") as f:
        f.write(cmd_input)
    print("Command line input is saved.")
    print_log(cmd_input)

    dataset = DglLinkPropPredDataset(name=args.dataset)
    split_edge = dataset.get_edge_split()
    graph = dataset[0]

    # Re-format the data of ogbl-citation2 into [N, 2] edge tensors.
    if args.dataset == "ogbl-citation2":
        for k in ["train", "valid", "test"]:
            src = split_edge[k]["source_node"]
            tgt = split_edge[k]["target_node"]
            split_edge[k]["edge"] = torch.stack([src, tgt], dim=1)
            if k != "train":
                tgt_neg = split_edge[k]["target_node_neg"]
                split_edge[k]["edge_neg"] = torch.stack(
                    [src[:, None].repeat(1, tgt_neg.size(1)), tgt_neg], dim=-1
                )  # [Ns, Nt, 2]

    # Reconstruct the graph for ogbl-collab data
    # for validation edge augmentation and coalesce.
    if args.dataset == "ogbl-collab":
        # Float edata for to_simple transformation.
        graph.edata.pop("year")
        graph.edata["weight"] = graph.edata["weight"].to(torch.float)
        if args.use_valedges_as_input:
            val_edges = split_edge["valid"]["edge"]
            row, col = val_edges.t()
            val_weights = torch.ones(size=(val_edges.size(0), 1))
            # Add validation edges in both directions.
            graph.add_edges(
                torch.cat([row, col]),
                torch.cat([col, row]),
                {"weight": val_weights},
            )
        graph = graph.to_simple(copy_edata=True, aggregator="sum")

    # Normalize the first three node-feature columns of ogbl-vessel.
    if args.dataset == "ogbl-vessel":
        graph.ndata["feat"][:, 0] = torch.nn.functional.normalize(
            graph.ndata["feat"][:, 0], dim=0
        )
        graph.ndata["feat"][:, 1] = torch.nn.functional.normalize(
            graph.ndata["feat"][:, 1], dim=0
        )
        graph.ndata["feat"][:, 2] = torch.nn.functional.normalize(
            graph.ndata["feat"][:, 2], dim=0
        )
        graph.ndata["feat"] = graph.ndata["feat"].to(torch.float)

    # Drop attributes the model is configured not to use.
    if not args.use_edge_weight and "weight" in graph.edata:
        del graph.edata["weight"]
    if not args.use_feature and "feat" in graph.ndata:
        del graph.ndata["feat"]

    directed = args.dataset.startswith("ogbl-citation")

    evaluator = Evaluator(name=args.dataset)
    # One Logger per reported metric, keyed by metric name.
    if dataset.eval_metric.startswith("hits@"):
        loggers = {
            f"hits@{k}": Logger(args.runs, args) for k in args.eval_hits_K
        }
    elif dataset.eval_metric == "mrr":
        loggers = {
            "mrr": Logger(args.runs, args),
        }
    elif dataset.eval_metric == "rocauc":
        loggers = {
            "rocauc": Logger(args.runs, args),
        }

    device = (
        f"cuda:{args.device}"
        if args.device != -1 and torch.cuda.is_available()
        else "cpu"
    )
    device = torch.device(device)

    path = f"{dataset.root}_seal{data_appendix}"
    if not (args.dynamic_train or args.dynamic_val or args.dynamic_test):
        # Fully cached datasets gain nothing from dataloader workers.
        args.num_workers = 0

    train_dataset, val_dataset, final_val_dataset, test_dataset = [
        SEALOGBLDataset(
            path,
            graph,
            split_edge,
            percent=percent,
            split=split,
            ratio_per_hop=args.ratio_per_hop,
            directed=directed,
            dynamic=dynamic,
        )
        for percent, split, dynamic in zip(
            [
                args.train_percent,
                args.val_percent,
                args.final_val_percent,
                args.test_percent,
            ],
            ["train", "valid", "valid", "test"],
            [
                args.dynamic_train,
                args.dynamic_val,
                args.dynamic_test,
                args.dynamic_test,
            ],
        )
    ]

    train_loader = GraphDataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=ogbl_collate_fn,
        num_workers=args.num_workers,
    )
    val_loader = GraphDataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=ogbl_collate_fn,
        num_workers=args.num_workers,
    )
    final_val_loader = GraphDataLoader(
        final_val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=ogbl_collate_fn,
        num_workers=args.num_workers,
    )
    test_loader = GraphDataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=ogbl_collate_fn,
        num_workers=args.num_workers,
    )

    if 0 < args.sortpool_k <= 1:  # Transform percentile to number.
        if args.dynamic_train:
            # Estimate the percentile on the first 1000 subgraphs only;
            # a full pass is expensive with on-the-fly extraction.
            _sampled_indices = range(1000)
        else:
            _sampled_indices = range(len(train_dataset))
        _num_nodes = sorted(
            [train_dataset[i][0].num_nodes() for i in _sampled_indices]
        )
        _k = _num_nodes[int(math.ceil(args.sortpool_k * len(_num_nodes))) - 1]
        model_k = max(10, _k)
    else:
        raise argparse.ArgumentTypeError("sortpool_k must be in range (0, 1].")

    print_log(f"training starts: {datetime.datetime.now()}")
    for run in range(args.runs):
        stime = datetime.datetime.now()
        print_log(f"\n++++++\n\nstart run [{run+1}], {stime}")

        model = DGCNN(
            args.hidden_channels,
            args.num_layers,
            args.max_z,
            model_k,
            feature_dim=graph.ndata["feat"].size(1)
            if (args.use_feature and "feat" in graph.ndata)
            else 0,
            dropout=args.dropout,
            ngnn_type=args.ngnn_type,
        ).to(device)
        parameters = list(model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=args.lr)
        # BUG FIX: was `sum(p.numel() for param in parameters for p in param)`,
        # which iterates each parameter tensor's first dimension and fails on
        # 0-dim parameters; summing numel per tensor is the correct count.
        total_params = sum(p.numel() for p in parameters)
        print_log(
            f"Total number of parameters is {total_params}",
            f"SortPooling k is set to {model.k}",
        )

        start_epoch = 1
        # Training starts.
        for epoch in range(start_epoch, start_epoch + args.epochs):
            epo_stime = datetime.datetime.now()
            loss = train()
            epo_train_etime = datetime.datetime.now()
            print_log(
                f"[epoch: {epoch}]",
                f" <Train> starts: {epo_stime}, "
                f"ends: {epo_train_etime}, "
                f"spent time:{epo_train_etime - epo_stime}",
            )
            if epoch % args.eval_steps == 0:
                epo_eval_stime = datetime.datetime.now()
                results = test(val_loader, loggers.keys())
                epo_eval_etime = datetime.datetime.now()
                print_log(
                    f" <Validation> starts: {epo_eval_stime}, "
                    f"ends: {epo_eval_etime}, "
                    f"spent time:{epo_eval_etime - epo_eval_stime}"
                )
                for key, valid_res in results.items():
                    loggers[key].add_result(run, valid_res)
                    to_print = (
                        f"Run: {run + 1:02d}, "
                        f"Epoch: {epoch:02d}, "
                        f"Loss: {loss:.4f}, "
                        f"Valid ({args.val_percent}%) [{key}]: {valid_res:.4f}"
                    )
                    print_log(key, to_print)
                # Checkpoint every eval point so the best epoch can be reloaded.
                model_name = os.path.join(
                    args.res_dir, f"run{run+1}_model_checkpoint{epoch}.pth"
                )
                optimizer_name = os.path.join(
                    args.res_dir, f"run{run+1}_optimizer_checkpoint{epoch}.pth"
                )
                torch.save(model.state_dict(), model_name)
                torch.save(optimizer.state_dict(), optimizer_name)
                print_log()

        tested = dict()  # epoch -> (final-val result, test result) cache
        for eval_metric in loggers.keys():
            # Select models according to the eval_metric of the dataset.
            res = torch.tensor(
                loggers[eval_metric].results["valid"][run]
            )
            if args.no_test:
                epoch = torch.argmax(res).item() + 1
                val_res = loggers[eval_metric].results["valid"][run][epoch - 1]
                loggers[eval_metric].add_result(run, (epoch, val_res), "test")
                print_log(
                    "No Test; Best Valid:",
                    f" Run: {run + 1:02d}, "
                    f"Epoch: {epoch:02d}, "
                    f"Valid ({args.val_percent}%) [{eval_metric}]: {val_res:.4f}",
                )
                continue
            idx_to_test = (
                torch.topk(res, args.test_topk, largest=True).indices + 1
            ).tolist()  # indices of top k valid results
            print_log(
                f"Eval Metric: {eval_metric}",
                f"Run: {run + 1:02d}, "
                f"Top {args.test_topk} Eval Points: {idx_to_test}",
            )
            for _idx, epoch in enumerate(idx_to_test):
                print_log(
                    f"Test Point[{_idx+1}]: "
                    f"Epoch {epoch:02d}, "
                    f"Test Metric: {dataset.eval_metric}"
                )
                if epoch not in tested:
                    # Reload the checkpoint of this eval point, then run the
                    # full validation and test sets once and cache the result.
                    model_name = os.path.join(
                        args.res_dir, f"run{run+1}_model_checkpoint{epoch}.pth"
                    )
                    optimizer_name = os.path.join(
                        args.res_dir,
                        f"run{run+1}_optimizer_checkpoint{epoch}.pth",
                    )
                    model.load_state_dict(torch.load(model_name))
                    optimizer.load_state_dict(torch.load(optimizer_name))
                    tested[epoch] = (
                        test(final_val_loader, dataset.eval_metric)[
                            dataset.eval_metric
                        ],
                        test(test_loader, dataset.eval_metric)[
                            dataset.eval_metric
                        ],
                    )
                val_res, test_res = tested[epoch]
                loggers[eval_metric].add_result(
                    run, (epoch, val_res, test_res), "test"
                )
                print_log(
                    f" Run: {run + 1:02d}, "
                    f"Epoch: {epoch:02d}, "
                    f"Valid ({args.val_percent}%) [{eval_metric}]: "
                    f"{loggers[eval_metric].results['valid'][run][epoch-1]:.4f}, "
                    f"Valid (final) [{dataset.eval_metric}]: {val_res:.4f}, "
                    f"Test [{dataset.eval_metric}]: {test_res:.4f}"
                )

        etime = datetime.datetime.now()
        print_log(
            f"end run [{run}], {etime}",
            f"spent time:{etime-stime}",
        )

    for key in loggers.keys():
        print(f"\n{key}")
        loggers[key].print_statistics()
        with open(log_file, "a") as f:
            print(f"\n{key}", file=f)
            loggers[key].print_statistics(f=f)
    print(f"Total number of parameters is {total_params}")
    print(f"Results are saved in {args.res_dir}")
import math
import torch
import torch.nn.functional as F
from dgl.nn import GraphConv, SortPooling
from torch.nn import Conv1d, Embedding, Linear, MaxPool1d, ModuleList
class NGNN_GCNConv(torch.nn.Module):
    """A GraphConv layer followed by a linear layer (Network-in-GNN block).

    forward() maps input_channels -> hidden_channels (GraphConv) and then
    hidden_channels -> output_channels (fc2), with a ReLU in between.

    NOTE(review): self.fc is created and re-initialized below but its use
    in forward() is commented out, so it only contributes unused (but
    counted and checkpointed) parameters — confirm whether it should be
    removed or re-enabled.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NGNN_GCNConv, self).__init__()
        self.conv = GraphConv(input_channels, hidden_channels)
        self.fc = Linear(hidden_channels, hidden_channels)
        self.fc2 = Linear(hidden_channels, output_channels)

    def reset_parameters(self):
        """Re-initialize all sub-modules (Xavier for the linear weights)."""
        self.conv.reset_parameters()
        gain = torch.nn.init.calculate_gain("relu")
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
        torch.nn.init.xavier_uniform_(self.fc2.weight, gain=gain)
        for bias in [self.fc.bias, self.fc2.bias]:
            # Uniform bias init scaled by fan-in, as in torch.nn.Linear.
            stdv = 1.0 / math.sqrt(bias.size(0))
            bias.data.uniform_(-stdv, stdv)

    def forward(self, g, x, edge_weight=None):
        x = self.conv(g, x, edge_weight)
        # x = F.relu(x)
        # x = self.fc(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
# An end-to-end deep learning architecture for graph classification, AAAI-18.
class DGCNN(torch.nn.Module):
    """DGCNN over SEAL enclosing subgraphs.

    Pipeline: DRNL-label embedding (plus optional raw node features) ->
    stacked graph convolutions (optionally NGNN blocks, per `ngnn_type`)
    -> SortPooling of the k top nodes -> two 1-D convolutions -> 2-layer
    MLP producing one logit per graph.
    """

    def __init__(
        self,
        hidden_channels,
        num_layers,
        max_z,
        k,
        feature_dim=0,
        GNN=GraphConv,
        NGNN=NGNN_GCNConv,
        dropout=0.0,
        ngnn_type="all",
    ):
        # k: number of nodes kept by SortPooling.
        # max_z: size of the DRNL-label embedding table.
        # ngnn_type: which conv layers use NGNN blocks
        #            ("none", "input", "hidden", "output" or "all").
        super(DGCNN, self).__init__()
        self.feature_dim = feature_dim
        self.dropout = dropout
        self.k = k
        self.sort_pool = SortPooling(k=self.k)
        self.max_z = max_z
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        self.convs = ModuleList()
        initial_channels = hidden_channels + self.feature_dim
        # Input conv layer.
        if ngnn_type in ["input", "all"]:
            self.convs.append(
                NGNN(initial_channels, hidden_channels, hidden_channels)
            )
        else:
            self.convs.append(GNN(initial_channels, hidden_channels))
        # Hidden conv layers.
        if ngnn_type in ["hidden", "all"]:
            for _ in range(0, num_layers - 1):
                self.convs.append(
                    NGNN(hidden_channels, hidden_channels, hidden_channels)
                )
        else:
            for _ in range(0, num_layers - 1):
                self.convs.append(GNN(hidden_channels, hidden_channels))
        # Output conv layer emits a single channel per node.
        if ngnn_type in ["output", "all"]:
            self.convs.append(NGNN(hidden_channels, hidden_channels, 1))
        else:
            self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        # Width of the concatenated per-node output of all conv layers:
        # num_layers outputs of `hidden_channels` plus the final 1-channel one.
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        # Kernel and stride equal total_latent_dim, so conv1 steps over
        # exactly one node's concatenated representation at a time.
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0], conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(
            conv1d_channels[0], conv1d_channels[1], conv1d_kws[1], 1
        )
        # Flattened size after conv1 -> maxpool -> conv2, for the MLP input.
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, g, z, x=None, edge_weight=None):
        # z: DRNL structural labels per node; x: optional raw node features.
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        # Collect each conv layer's output so they can all be concatenated.
        xs = [x]
        for conv in self.convs:
            xs += [
                F.dropout(
                    torch.tanh(conv(g, xs[-1], edge_weight=edge_weight)),
                    p=self.dropout,
                    training=self.training,
                )
            ]
        x = torch.cat(xs[1:], dim=-1)

        # global pooling
        x = self.sort_pool(g, x)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
import random
import sys
import numpy as np
import torch
from dgl.sampling import global_uniform_negative_sampling
from scipy.sparse.csgraph import shortest_path
def k_hop_subgraph(src, dst, num_hops, g, sample_ratio=1.0, directed=False):
    """Extract the k-hop enclosing subgraph around link (src, dst) from g.

    Nodes are ordered with src and dst first; the returned subgraph keeps
    the original node/edge IDs (store_ids=True).  When sample_ratio < 1,
    only that fraction of each hop's new fringe is kept.

    Fix vs. the original: `random.sample()` no longer accepts a set
    (TypeError since Python 3.11), so the fringe is converted to a sorted
    list before sampling.
    """
    nodes = [src, dst]
    visited = set([src, dst])
    fringe = set([src, dst])
    for _ in range(num_hops):
        if not directed:
            _, neighbors = g.out_edges(list(fringe))
            fringe = set(neighbors.tolist()) - visited
        else:
            # For directed graphs, expand along both edge directions.
            _, out_neighbors = g.out_edges(list(fringe))
            in_neighbors, _ = g.in_edges(list(fringe))
            fringe = set(in_neighbors.tolist() + out_neighbors.tolist()) - visited
        visited = visited.union(fringe)
        if sample_ratio < 1.0:
            # BUG FIX: sample from a sorted list, not a set (Python >= 3.11).
            fringe = random.sample(
                sorted(fringe), int(sample_ratio * len(fringe))
            )
        if len(fringe) == 0:
            break
        nodes = nodes + list(fringe)
    subg = g.subgraph(nodes, store_ids=True)
    return subg
def drnl_node_labeling(adj, src, dst):
    """Double Radius Node Labeling (DRNL) for a SEAL enclosing subgraph.

    adj is the scipy CSR adjacency of the subgraph; src/dst are the two
    target nodes.  Each node's label encodes its pair of shortest-path
    distances to src (computed with dst removed) and to dst (computed
    with src removed).  Unreachable nodes get label 0; the targets get 1.
    Returns a long tensor of labels, one per node.
    """
    if src > dst:
        src, dst = dst, src
    n = adj.shape[0]

    # Adjacency with one endpoint deleted, so distances ignore paths
    # running through the other target.
    keep_wo_src = list(range(src)) + list(range(src + 1, n))
    keep_wo_dst = list(range(dst)) + list(range(dst + 1, n))
    adj_wo_src = adj[keep_wo_src, :][:, keep_wo_src]
    adj_wo_dst = adj[keep_wo_dst, :][:, keep_wo_dst]

    d_src = shortest_path(
        adj_wo_dst, directed=False, unweighted=True, indices=src
    )
    d_src = torch.from_numpy(np.insert(d_src, dst, 0, axis=0))

    # dst's index shifts down by one after src (< dst) is removed.
    d_dst = shortest_path(
        adj_wo_src, directed=False, unweighted=True, indices=dst - 1
    )
    d_dst = torch.from_numpy(np.insert(d_dst, src, 0, axis=0))

    total = d_src + d_dst
    half = torch.div(total, 2, rounding_mode="floor")
    parity = total % 2
    labels = 1 + torch.min(d_src, d_dst) + half * (half + parity - 1)
    labels[src] = 1.0
    labels[dst] = 1.0
    # shortest path may include inf values
    labels[torch.isnan(labels)] = 0.0
    return labels.to(torch.long)
def get_pos_neg_edges(split, split_edge, g, percent=100):
    """Return (pos_edge, neg_edge) tensors of shape [N, 2] for a split.

    Training negatives are drawn by uniform negative sampling on g; the
    other splits use their precomputed negatives.  Both sides are then
    deterministically subsampled to `percent`% with a fixed seed.
    """

    def _keep_indices(n):
        # Fixed seed so every call (and both edge sides) sees the same order.
        np.random.seed(123)
        return np.random.permutation(n)[: int(percent / 100 * n)]

    pos_edge = split_edge[split]["edge"]
    if split == "train":
        sampled = global_uniform_negative_sampling(
            g, num_samples=pos_edge.size(0), exclude_self_loops=True
        )
        neg_edge = torch.stack(sampled, dim=1)
    else:
        neg_edge = split_edge[split]["edge_neg"]

    # Subsample positives.
    keep = _keep_indices(pos_edge.size(0))
    pos_edge = pos_edge[keep]

    # Subsample negatives.
    if neg_edge.dim() > 2:  # [Np, Nn, 2]: negatives are tied to positives
        neg_edge = neg_edge[keep].view(-1, 2)
    else:
        neg_edge = neg_edge[_keep_indices(neg_edge.size(0))]
    return pos_edge, neg_edge
class Logger(object):
    """Accumulates per-run validation/test results and prints statistics."""

    def __init__(self, runs, info=None):
        # info: typically the parsed args; only `info.no_test` is read here.
        self.info = info
        self.results = {
            "valid": [[] for _ in range(runs)],
            "test": [[] for _ in range(runs)],
        }

    def add_result(self, run, result, split="valid"):
        """Append one result for `run` to the given split's history."""
        assert run >= 0 and run < len(self.results["valid"])
        assert split in ["valid", "test"]
        self.results[split][run].append(result)

    def print_statistics(self, run=None, f=sys.stdout):
        """Print stats for one run, or aggregates over all runs when run is None."""
        if run is not None:
            valid = torch.tensor(self.results["valid"][run])
            print(f"Run {run + 1:02d}:", file=f)
            print(f"Highest Valid: {valid.max():.4f}", file=f)
            print(f"Highest Eval Point: {valid.argmax().item()+1}", file=f)
            if not self.info.no_test:
                print(
                    f' Final Test Point[1]: {self.results["test"][run][0][0]}',
                    f' Final Valid: {self.results["test"][run][0][1]}',
                    f' Final Test: {self.results["test"][run][0][2]}',
                    sep='\n',
                    file=f,
                )
            return

        # Aggregate over runs: one (epoch, valid[, test]) record per run.
        best_result = torch.tensor(
            [records[0] for records in self.results["test"]]
        )
        print(f"All runs:", file=f)
        r = best_result[:, 1]
        print(f"Highest Valid: {r.mean():.4f} ± {r.std():.4f}", file=f)
        if not self.info.no_test:
            r = best_result[:, 2]
            print(f" Final Test: {r.mean():.4f} ± {r.std():.4f}", file=f)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment