"vscode:/vscode.git/clone" did not exist on "c991ffd4f070beb74d4281fba2ee8c49c82d69b7"
Unverified Commit bdbc0178 authored by Muhammed Fatih BALIN, committed by GitHub

[Example] Labor sampling (#4718)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent d7410cf4
@@ -5,7 +5,10 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (E.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)
To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).
## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
- Tags: node classification, weighted graphs, sampling
## 2021
- <a name="rnaglib"></a> Mallet et al. Learning Protein and Small Molecule binding sites in RNA molecules with 2.5D graphs. [Paper link](https://academic.oup.com/bioinformatics/article/38/5/1458/6462185?login=true)
- Example code: [PyTorch](https://jwgitlab.cs.mcgill.ca/cgoliver/rnaglib)
# referenced the following implementation: https://github.com/BarclayII/dgl/blob/ladies/examples/pytorch/ladies/ladies2.py
import dgl
import dgl.function as fn
import torch
def find_indices_in(a, b):
b_sorted, indices = torch.sort(b)
sorted_indices = torch.searchsorted(b_sorted, a)
sorted_indices[sorted_indices >= indices.shape[0]] = 0
return indices[sorted_indices]
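# find_indices_in(a, b) returns, for every element of a, its position in b; it assumes
# every element of a also occurs in b. A hypothetical sanity check (not part of the
# example): find_indices_in(torch.tensor([30, 10]), torch.tensor([10, 20, 30]))
# yields tensor([2, 0]).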
def union(*arrays):
return torch.unique(torch.cat(arrays))
def normalized_edata(g, weight=None):
with g.local_scope():
if weight is None:
weight = "W"
g.edata[weight] = torch.ones(g.number_of_edges(), device=g.device)
g.update_all(fn.copy_e(weight, weight), fn.sum(weight, "v"))
g.apply_edges(lambda edges: {"w": 1 / edges.dst["v"]})
return g.edata["w"]
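# normalized_edata() above assigns each edge the weight 1 / in_degree(dst), so every
# node's incoming edge weights sum to 1; these are the default "w" edge weights
# consumed by the samplers below.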
class LadiesSampler(dgl.dataloading.BlockSampler):
def __init__(
self,
nodes_per_layer,
importance_sampling=True,
weight="w",
out_weight="edge_weights",
replace=False,
):
super().__init__()
self.nodes_per_layer = nodes_per_layer
self.importance_sampling = importance_sampling
self.edge_weight = weight
self.output_weight = out_weight
self.replace = replace
def compute_prob(self, g, seed_nodes, weight, num):
"""
g : the whole graph
seed_nodes : the output nodes for the current layer
weight : the weight of the edges
return : the unnormalized probability of the candidate nodes, as well as the subgraph
containing all the edges from the candidate nodes to the output nodes.
"""
insg = dgl.in_subgraph(g, seed_nodes)
insg = dgl.compact_graphs(insg, seed_nodes)
if self.importance_sampling:
out_frontier = dgl.reverse(insg, copy_edata=True)
weight = weight[out_frontier.edata[dgl.EID].long()]
prob = dgl.ops.copy_e_sum(out_frontier, weight**2)
# prob = torch.sqrt(prob)
else:
prob = torch.ones(insg.num_nodes())
prob[insg.out_degrees() == 0] = 0
return prob, insg
def select_neighbors(self, prob, num):
"""
prob : unnormalized probability of each candidate node in the compacted subgraph
num : number of neighbor nodes to sample
return : the indices (local to the compacted subgraph from compute_prob()) of the
sampled neighbor nodes.
"""
# The final input node set is the union of the seed nodes and the @num sampled
# candidates; that union is formed later in generate_block(). Because compute_prob()
# returns a compacted subgraph together with its probabilities, generate_block() uses
# find_indices_in() to map global node IDs back to local IDs in that subgraph when
# computing the block edge weights.
neighbor_nodes_idx = torch.multinomial(
prob, min(num, prob.shape[0]), replacement=self.replace
)
return neighbor_nodes_idx
def generate_block(self, insg, neighbor_nodes_idx, seed_nodes, P_sg, W_sg):
"""
insg : the subgraph yielded by compute_prob()
neighbor_nodes_idx : the sampled nodes from the subgraph @insg, yielded by select_neighbors()
seed_nodes : the global IDs of the seed (output) nodes for this layer
P_sg : unnormalized probability of each node being sampled, yielded by compute_prob()
W_sg : edge weights of @insg
return : the block.
"""
seed_nodes_idx = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
u_nodes = union(neighbor_nodes_idx, seed_nodes_idx)
sg = insg.subgraph(u_nodes.type(insg.idtype))
u, v = sg.edges()
lu = sg.ndata[dgl.NID][u.long()]
s = find_indices_in(lu, neighbor_nodes_idx)
eg = dgl.edge_subgraph(
sg, lu == neighbor_nodes_idx[s], relabel_nodes=False
)
eg.ndata[dgl.NID] = sg.ndata[dgl.NID][: eg.num_nodes()]
eg.edata[dgl.EID] = sg.edata[dgl.EID][eg.edata[dgl.EID].long()]
sg = eg
nids = insg.ndata[dgl.NID][sg.ndata[dgl.NID].long()]
P = P_sg[u_nodes.long()]
W = W_sg[sg.edata[dgl.EID].long()]
W_tilde = dgl.ops.e_div_u(sg, W, P)
W_tilde_sum = dgl.ops.copy_e_sum(sg, W_tilde)
d = sg.in_degrees()
W_tilde = dgl.ops.e_mul_v(sg, W_tilde, d / W_tilde_sum)
block = dgl.to_block(sg, seed_nodes_idx.type(sg.idtype))
block.edata[self.output_weight] = W_tilde
# correct node ID mapping
block.srcdata[dgl.NID] = nids[block.srcdata[dgl.NID].long()]
block.dstdata[dgl.NID] = nids[block.dstdata[dgl.NID].long()]
sg_eids = insg.edata[dgl.EID][sg.edata[dgl.EID].long()]
block.edata[dgl.EID] = sg_eids[block.edata[dgl.EID].long()]
return block
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_nodes = seed_nodes
blocks = []
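# Blocks are generated from the output layer back toward the input layer: the source
# nodes of each new block become the seed nodes for the next (deeper) layer, and blocks
# are prepended so the returned list runs input-to-output.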
for block_id in reversed(range(len(self.nodes_per_layer))):
num_nodes_to_sample = self.nodes_per_layer[block_id]
W = g.edata[self.edge_weight]
prob, insg = self.compute_prob(
g, seed_nodes, W, num_nodes_to_sample
)
neighbor_nodes_idx = self.select_neighbors(
prob, num_nodes_to_sample
)
block = self.generate_block(
insg,
neighbor_nodes_idx.type(g.idtype),
seed_nodes.type(g.idtype),
prob,
W[insg.edata[dgl.EID].long()],
)
seed_nodes = block.srcdata[dgl.NID]
blocks.insert(0, block)
return seed_nodes, output_nodes, blocks
class PoissonLadiesSampler(LadiesSampler):
def __init__(
self,
nodes_per_layer,
importance_sampling=True,
weight="w",
out_weight="edge_weights",
skip=False,
):
super().__init__(
nodes_per_layer, importance_sampling, weight, out_weight
)
self.eps = 0.9999
self.skip = skip
def compute_prob(self, g, seed_nodes, weight, num):
"""
g : the whole graph
seed_nodes : the output nodes for the current layer
weight : the weight of the edges
return : per-candidate inclusion probabilities scaled into [0, 1] so that roughly
@num nodes are sampled in expectation, as well as the subgraph containing all the
edges from the candidate nodes to the output nodes.
"""
prob, insg = super().compute_prob(g, seed_nodes, weight, num)
one = torch.ones_like(prob)
if prob.shape[0] <= num:
return one, insg
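# Otherwise, find a scale factor c such that sum_v min(c * prob_v, 1) is approximately
# @num, so that the independent Bernoulli trials in select_neighbors() draw about @num
# neighbors in expectation.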
c = 1.0
for i in range(50):
S = torch.sum(torch.minimum(prob * c, one).to(torch.float64)).item()
if min(S, num) / max(S, num) >= self.eps:
break
else:
c *= num / S
if self.skip:
skip_nodes = find_indices_in(seed_nodes, insg.ndata[dgl.NID])
prob[skip_nodes] = float("inf")
return torch.minimum(prob * c, one), insg
def select_neighbors(self, prob, num):
"""
prob : per-candidate inclusion probability in [0, 1], as returned by compute_prob()
num : target number of neighbors to sample (matched only in expectation)
return : the indices (local to the compacted subgraph) of the candidates whose
independent Bernoulli trial succeeded.
"""
# Unlike LadiesSampler.select_neighbors(), every candidate is included independently
# with its own probability (Poisson sampling), so the number of sampled neighbors
# equals @num only in expectation.
neighbor_nodes_idx = torch.arange(prob.shape[0], device=prob.device)[
torch.bernoulli(prob) == 1
]
return neighbor_nodes_idx
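# A minimal usage sketch for the samplers defined above in ladies_sampler.py (assuming a
# homogeneous DGLGraph `g` and a tensor of seed node IDs `train_nid`; illustrative only,
# not part of the example):
#
#   g.edata["w"] = normalized_edata(g)
#   sampler = LadiesSampler([4000, 4000, 4000])
#   dataloader = dgl.dataloading.DataLoader(
#       g, train_nid, sampler, batch_size=1024, shuffle=True, drop_last=True)
#   for input_nodes, output_nodes, blocks in dataloader:
#       pass  # blocks[i].edata["edge_weights"] holds the debiasing edge weights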
import dgl
import torch as th
def load_data(data):
g = data[0]
g.ndata["features"] = g.ndata.pop("feat")
g.ndata["labels"] = g.ndata.pop("label")
return g, data.num_classes
def load_dgl(name):
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
FlickrDataset,
PubmedGraphDataset,
RedditDataset,
YelpDataset,
)
d = {
"cora": CoraGraphDataset,
"citeseer": CiteseerGraphDataset,
"pubmed": PubmedGraphDataset,
"reddit": RedditDataset,
"yelp": YelpDataset,
"flickr": FlickrDataset,
}
return load_data(d[name]())
def load_reddit(self_loop=True):
from dgl.data import RedditDataset
# load reddit data
data = RedditDataset(self_loop=self_loop)
return load_data(data)
def load_mag240m(root="dataset"):
from os.path import join
import numpy as np
from ogb.lsc import MAG240MDataset
dataset = MAG240MDataset(root=root)
print("Loading graph")
(g,), _ = dgl.load_graphs(join(root, "mag240m_kddcup2021/graph.dgl"))
print("Loading features")
paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features
feats = th.from_numpy(
np.memmap(
join(root, "mag240m_kddcup2021/full.npy"),
mode="r",
dtype="float16",
shape=(num_nodes, num_features),
)
).float()
g.ndata["features"] = feats
train_nid = th.LongTensor(dataset.get_idx_split("train")) + paper_offset
val_nid = th.LongTensor(dataset.get_idx_split("valid")) + paper_offset
test_nid = th.LongTensor(dataset.get_idx_split("test-dev")) + paper_offset
train_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
g.ndata["train_mask"] = train_mask
g.ndata["val_mask"] = val_mask
g.ndata["test_mask"] = test_mask
labels = th.tensor(dataset.paper_label)
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
g.ndata["labels"] = -th.ones(g.number_of_nodes(), dtype=th.int64)
g.ndata["labels"][train_nid] = labels[train_nid - paper_offset].long()
g.ndata["labels"][val_nid] = labels[val_nid - paper_offset].long()
return g, num_labels
def load_ogb(name, root="dataset"):
if name == "ogbn-mag240M":
return load_mag240m(root)
from ogb.nodeproppred import DglNodePropPredDataset
print("load", name)
data = DglNodePropPredDataset(name=name, root=root)
print("finish loading", name)
splitted_idx = data.get_idx_split()
graph, labels = data[0]
labels = labels[:, 0]
graph.ndata["features"] = graph.ndata.pop("feat")
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
graph.ndata["labels"] = labels.type(th.LongTensor)
in_feats = graph.ndata["features"].shape[1]
# Find the node IDs in the training, validation, and test set.
train_nid, val_nid, test_nid = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
train_mask[train_nid] = True
val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
val_mask[val_nid] = True
test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
test_mask[test_nid] = True
graph.ndata["train_mask"] = train_mask
graph.ndata["val_mask"] = val_mask
graph.ndata["test_mask"] = test_mask
print("finish constructing", name)
return graph, num_labels
def load_dataset(dataset_name):
multilabel = False
if dataset_name in [
"reddit",
"cora",
"citeseer",
"pubmed",
"yelp",
"flickr",
]:
g, n_classes = load_dgl(dataset_name)
multilabel = dataset_name in ["yelp"]
if multilabel:
g.ndata["labels"] = g.ndata["labels"].to(dtype=th.float32)
elif dataset_name in [
"ogbn-products",
"ogbn-arxiv",
"ogbn-papers100M",
"ogbn-mag240M",
]:
g, n_classes = load_ogb(dataset_name)
else:
raise ValueError("unknown dataset")
return g, n_classes, multilabel
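# A minimal usage sketch for load_dataset() above (illustrative only):
#   g, n_classes, multilabel = load_dataset("ogbn-arxiv")
#   print(g.num_nodes(), n_classes, multilabel)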
import dgl
import dgl.nn as dglnn
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from dgl.nn import GATv2Conv
class GATv2(nn.Module):
def __init__(
self,
num_layers,
in_dim,
num_hidden,
num_classes,
heads,
activation,
feat_drop,
attn_drop,
negative_slope,
residual,
):
super(GATv2, self).__init__()
self.num_layers = num_layers
self.gatv2_layers = nn.ModuleList()
self.activation = activation
# input projection (no residual)
self.gatv2_layers.append(
GATv2Conv(
in_dim,
num_hidden,
heads[0],
feat_drop,
attn_drop,
negative_slope,
False,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# hidden layers
for l in range(1, num_layers - 1):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[l - 1],
num_hidden,
heads[l],
feat_drop,
attn_drop,
negative_slope,
residual,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# output projection
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[-2],
num_classes,
heads[-1],
feat_drop,
attn_drop,
negative_slope,
residual,
None,
True,
bias=False,
share_weights=True,
)
)
def forward(self, mfgs, h):
for l, mfg in enumerate(mfgs):
h = self.gatv2_layers[l](mfg, h)
h = h.flatten(1) if l < self.num_layers - 1 else h.mean(1)
return h
class SAGE(nn.Module):
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.init(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
def init(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
if n_layers > 1:
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
else:
self.layers.append(dglnn.SAGEConv(in_feats, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(
block,
h,
edge_weight=block.edata["edge_weights"]
if "edge_weights" in block.edata
else None,
)
if l != len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
return h
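# Note: when a block carries "edge_weights" in its edata (e.g. blocks produced by the
# LADIES samplers in ladies_sampler.py), SAGEConv's edge_weight argument applies them
# during message aggregation; blocks without that key fall back to a plain unweighted
# mean.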
class RGAT(nn.Module):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
num_etypes,
num_layers,
num_heads,
dropout,
pred_ntype,
):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.skips = nn.ModuleList()
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
in_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(
nn.ModuleList(
[
dglnn.GATConv(
hidden_channels,
hidden_channels // num_heads,
num_heads,
allow_zero_in_degree=True,
)
for _ in range(num_etypes)
]
)
)
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels))
self.mlp = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels),
)
self.dropout = nn.Dropout(dropout)
self.hidden_channels = hidden_channels
self.pred_ntype = pred_ntype
self.num_etypes = num_etypes
def forward(self, mfgs, x):
for i in range(len(mfgs)):
mfg = mfgs[i]
x_dst = x[mfg.dst_in_src]
for data in [mfg.srcdata, mfg.dstdata]:
for k in list(data.keys()):
if k not in ["features", "labels"]:
data.pop(k)
mfg = dgl.block_to_graph(mfg)
x_skip = self.skips[i](x_dst)
for j in range(self.num_etypes):
subg = mfg.edge_subgraph(
mfg.edata["etype"] == j, relabel_nodes=False
)
x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-1, self.hidden_channels
)
x = self.norms[i](x_skip)
x = th.nn.functional.elu(x)
x = self.dropout(x)
return self.mlp(x)
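# A minimal construction sketch using the reddit dataset shapes (602 input features,
# 41 classes; illustrative only):
#   model = SAGE(in_feats=602, n_hidden=256, n_classes=41, n_layers=3,
#                activation=th.nn.functional.relu, dropout=0.5)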
# /*!
# * Copyright (c) 2022, NVIDIA Corporation
# * Copyright (c) 2022, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
# * All rights reserved.
# *
# * Licensed under the Apache License, Version 2.0 (the "License");
# * you may not use this file except in compliance with the License.
# * You may obtain a copy of the License at
# *
# * http://www.apache.org/licenses/LICENSE-2.0
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# *
# * @file train_lightning.py
# * @brief labor sampling example
# */
import argparse
import glob
import math
import os
import time
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ladies_sampler import LadiesSampler, normalized_edata, PoissonLadiesSampler
from load_graph import load_dataset
from model import GATv2, RGAT, SAGE
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score
def cuda_index_tensor(tensor, idx):
assert idx.device != th.device("cpu")
if tensor.is_pinned():
return dgl.utils.gather_pinned_tensor_rows(tensor, idx)
else:
return tensor[idx.long()]
class SAGELightning(LightningModule):
def __init__(
self,
in_feats,
n_hidden,
n_classes,
n_layers,
model,
activation,
dropout,
lr,
multilabel,
):
super().__init__()
self.save_hyperparameters()
if model in ["sage"]:
self.module = (
SAGE(
in_feats, n_hidden, n_classes, n_layers, activation, dropout
)
if in_feats != 768
else RGAT(
in_feats,
n_classes,
n_hidden,
5,
n_layers,
4,
dropout,
"paper",
)
)
else:
heads = ([8] * n_layers) + [1]
self.module = GATv2(
n_layers,
in_feats,
n_hidden,
n_classes,
heads,
activation,
dropout,
dropout,
0.2,
True,
)
self.lr = lr
f1score_class = (
MulticlassF1Score if not multilabel else MultilabelF1Score
)
self.train_acc = f1score_class(n_classes, average="micro")
self.val_acc = nn.ModuleList(
[
f1score_class(n_classes, average="micro"),
f1score_class(n_classes, average="micro"),
]
)
self.test_acc = nn.ModuleList(
[
f1score_class(n_classes, average="micro"),
f1score_class(n_classes, average="micro"),
]
)
self.num_steps = 0
self.cum_sampled_nodes = [0 for _ in range(n_layers + 1)]
self.cum_sampled_edges = [0 for _ in range(n_layers)]
self.w = 0.99
self.loss_fn = (
nn.CrossEntropyLoss() if not multilabel else nn.BCEWithLogitsLoss()
)
self.pt = 0
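# num_sampled_nodes()/num_sampled_edges() below report either a plain running average
# (when w >= 1) or an exponential moving average debiased by
# (1 - w) / (1 - w**num_steps), the same bias correction used for Adam's moment
# estimates.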
def num_sampled_nodes(self, i):
return (
self.cum_sampled_nodes[i] / self.num_steps
if self.w >= 1
else self.cum_sampled_nodes[i]
* (1 - self.w)
/ (1 - self.w**self.num_steps)
)
def num_sampled_edges(self, i):
return (
self.cum_sampled_edges[i] / self.num_steps
if self.w >= 1
else self.cum_sampled_edges[i]
* (1 - self.w)
/ (1 - self.w**self.num_steps)
)
def training_step(self, batch, batch_idx):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
self.num_steps += 1
for i, mfg in enumerate(mfgs):
self.cum_sampled_nodes[i] = (
self.cum_sampled_nodes[i] * self.w + mfg.num_src_nodes()
)
self.cum_sampled_edges[i] = (
self.cum_sampled_edges[i] * self.w + mfg.num_edges()
)
self.log(
"num_nodes/{}".format(i),
self.num_sampled_nodes(i),
prog_bar=True,
on_step=True,
on_epoch=False,
)
self.log(
"num_edges/{}".format(i),
self.num_sampled_edges(i),
prog_bar=True,
on_step=True,
on_epoch=False,
)
# for batch size monitoring
i = len(mfgs)
self.cum_sampled_nodes[i] = (
self.cum_sampled_nodes[i] * self.w + mfgs[-1].num_dst_nodes()
)
self.log(
"num_nodes/{}".format(i),
self.num_sampled_nodes(i),
prog_bar=True,
on_step=True,
on_epoch=False,
)
batch_inputs = mfgs[0].srcdata["features"]
batch_labels = mfgs[-1].dstdata["labels"]
self.st = time.time()
batch_pred = self.module(mfgs, batch_inputs)
loss = self.loss_fn(batch_pred, batch_labels)
self.train_acc(batch_pred, batch_labels.int())
self.log(
"train_acc",
self.train_acc,
prog_bar=True,
on_step=True,
on_epoch=True,
batch_size=batch_labels.shape[0],
)
self.log(
"train_loss",
loss,
on_step=True,
on_epoch=True,
batch_size=batch_labels.shape[0],
)
t = time.time()
self.log(
"iter_time",
t - self.pt,
prog_bar=True,
on_step=True,
on_epoch=False,
)
self.pt = t
return loss
def on_train_batch_end(self, outputs, batch, batch_idx):
self.log(
"forward_backward_time",
time.time() - self.st,
prog_bar=True,
on_step=True,
on_epoch=False,
)
def validation_step(self, batch, batch_idx, dataloader_idx=0):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
batch_inputs = mfgs[0].srcdata["features"]
batch_labels = mfgs[-1].dstdata["labels"]
batch_pred = self.module(mfgs, batch_inputs)
loss = self.loss_fn(batch_pred, batch_labels)
self.val_acc[dataloader_idx](batch_pred, batch_labels.int())
self.log(
"val_acc",
self.val_acc[dataloader_idx],
prog_bar=True,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
self.log(
"val_loss",
loss,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
def test_step(self, batch, batch_idx, dataloader_idx=0):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
batch_inputs = mfgs[0].srcdata["features"]
batch_labels = mfgs[-1].dstdata["labels"]
batch_pred = self.module(mfgs, batch_inputs)
loss = self.loss_fn(batch_pred, batch_labels)
self.test_acc[dataloader_idx](batch_pred, batch_labels.int())
self.log(
"test_acc",
self.test_acc[dataloader_idx],
prog_bar=True,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
self.log(
"test_loss",
loss,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
def configure_optimizers(self):
optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
class DataModule(LightningDataModule):
def __init__(
self,
dataset_name,
undirected,
data_cpu=False,
use_uva=False,
fan_out=[10, 25],
lad_out=[11000, 5000],
device=th.device("cpu"),
batch_size=1000,
num_workers=4,
sampler="labor",
importance_sampling=0,
layer_dependency=False,
batch_dependency=1,
cache_size=0,
):
super().__init__()
g, n_classes, multilabel = load_dataset(dataset_name)
if undirected:
src, dst = g.all_edges()
g.add_edges(dst, src)
cast_to_int = max(g.num_nodes(), g.num_edges()) <= 2e9
if cast_to_int:
g = g.int()
train_nid = th.nonzero(g.ndata["train_mask"], as_tuple=True)[0]
val_nid = th.nonzero(g.ndata["val_mask"], as_tuple=True)[0]
test_nid = th.nonzero(g.ndata["test_mask"], as_tuple=True)[0]
fanouts = [int(_) for _ in fan_out]
ladouts = [int(_) for _ in lad_out]
if sampler == "neighbor":
sampler = dgl.dataloading.NeighborSampler(
fanouts,
prefetch_node_feats=["features"],
prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
prefetch_labels=["labels"],
)
elif "ladies" in sampler:
g.edata["w"] = normalized_edata(g)
sampler = (
PoissonLadiesSampler if "poisson" in sampler else LadiesSampler
)(ladouts)
else:
sampler = dgl.dataloading.LaborSampler(
fanouts,
importance_sampling=importance_sampling,
layer_dependency=layer_dependency,
batch_dependency=batch_dependency,
prefetch_node_feats=["features"],
prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
prefetch_labels=["labels"],
)
full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
len(fanouts),
prefetch_node_feats=["features"],
prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
prefetch_labels=["labels"],
)
unbiased_sampler = sampler
dataloader_device = th.device("cpu")
g = g.formats(["csc"])
if use_uva or not data_cpu:
train_nid = train_nid.to(device)
val_nid = val_nid.to(device)
test_nid = test_nid.to(device)
if not data_cpu and not use_uva:
g = g.to(device)
dataloader_device = device
self.g = g
if cast_to_int:
self.train_nid, self.val_nid, self.test_nid = (
train_nid.int(),
val_nid.int(),
test_nid.int(),
)
else:
self.train_nid, self.val_nid, self.test_nid = (
train_nid,
val_nid,
test_nid,
)
self.sampler = sampler
self.unbiased_sampler = unbiased_sampler
self.full_sampler = full_sampler
self.device = dataloader_device
self.use_uva = use_uva
self.batch_size = batch_size
self.num_workers = num_workers
self.in_feats = g.ndata["features"].shape[1]
self.n_classes = n_classes
self.multilabel = multilabel
self.gpu_cache_arg = {"node": {"features": cache_size}}
def train_dataloader(self):
return dgl.dataloading.DataLoader(
self.g,
self.train_nid,
self.sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
def val_dataloader(self):
return [
dgl.dataloading.DataLoader(
self.g,
self.val_nid,
sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
for sampler in [self.unbiased_sampler]
]
def test_dataloader(self):
return [
dgl.dataloading.DataLoader(
self.g,
self.test_nid,
sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
for sampler in [self.full_sampler]
]
class BatchSizeCallback(Callback):
def __init__(self, limit, factor=3):
super().__init__()
self.limit = limit
self.factor = factor
self.clear()
def clear(self):
self.n = 0
self.m = 0
self.s = 0
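# push() maintains a running mean (m) and sum of squared deviations (s) of the
# per-batch input-node counts using Welford's online algorithm.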
def push(self, x):
self.n += 1
m = self.m
self.m += (x - m) / self.n
self.s += (x - m) * (x - self.m)
@property
def var(self):
return self.s / (self.n - 1)
@property
def std(self):
return math.sqrt(self.var)
def on_train_batch_start(self, trainer, datamodule, batch, batch_idx):
input_nodes, output_nodes, mfgs = batch
features = mfgs[0].srcdata["features"]
if hasattr(features, "__cache_miss__"):
trainer.strategy.model.log(
"cache_miss",
features.__cache_miss__,
prog_bar=True,
on_step=True,
on_epoch=False,
)
def on_train_batch_end(
self, trainer, datamodule, outputs, batch, batch_idx
):
input_nodes, output_nodes, mfgs = batch
self.push(mfgs[0].num_src_nodes())
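# At the end of an epoch, if the running mean input-node count deviates from the
# --vertex-limit target by more than the tolerance below, the batch size is rescaled by
# limit / mean and the train/val dataloaders are rebuilt.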
def on_train_epoch_end(self, trainer, datamodule):
if (
self.limit > 0
and self.n >= 2
and abs(self.limit - self.m) * self.n >= self.std * self.factor
):
trainer.datamodule.batch_size = int(
trainer.datamodule.batch_size * self.limit / self.m
)
trainer.reset_train_dataloader()
trainer.reset_val_dataloader()
self.clear()
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--dataset", type=str, default="reddit")
argparser.add_argument("--num-epochs", type=int, default=-1)
argparser.add_argument("--num-steps", type=int, default=-1)
argparser.add_argument("--min-steps", type=int, default=0)
argparser.add_argument("--num-hidden", type=int, default=256)
argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument("--fan-out", type=str, default="10,10,10")
argparser.add_argument("--lad-out", type=str, default="16000,11000,5000")
argparser.add_argument("--batch-size", type=int, default=1024)
argparser.add_argument("--lr", type=float, default=0.001)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument("--independent-batches", type=int, default=1)
argparser.add_argument(
"--num-workers",
type=int,
default=0,
help="Number of sampling processes. Use 0 for no extra process.",
)
argparser.add_argument(
"--data-cpu",
action="store_true",
help="By default the script puts the node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.",
)
argparser.add_argument("--model", type=str, default="sage")
argparser.add_argument("--sampler", type=str, default="labor")
argparser.add_argument("--importance-sampling", type=int, default=0)
argparser.add_argument("--layer-dependency", action="store_true")
argparser.add_argument("--batch-dependency", type=int, default=1)
argparser.add_argument("--logdir", type=str, default="tb_logs")
argparser.add_argument("--vertex-limit", type=int, default=-1)
argparser.add_argument("--use-uva", action="store_true")
argparser.add_argument("--cache-size", type=int, default=0)
argparser.add_argument("--undirected", action="store_true")
argparser.add_argument("--val-acc-target", type=float, default=1)
argparser.add_argument("--early-stopping-patience", type=int, default=10)
argparser.add_argument("--disable-checkpoint", action="store_true")
argparser.add_argument("--precision", type=str, default="highest")
args = argparser.parse_args()
if args.precision != "highest":
th.set_float32_matmul_precision(args.precision)
if args.gpu >= 0:
device = th.device("cuda:%d" % args.gpu)
else:
device = th.device("cpu")
datamodule = DataModule(
args.dataset,
args.undirected,
args.data_cpu,
args.use_uva,
[int(_) for _ in args.fan_out.split(",")],
[int(_) for _ in args.lad_out.split(",")],
device,
args.batch_size // args.independent_batches,
args.num_workers,
args.sampler,
args.importance_sampling,
args.layer_dependency,
args.batch_dependency,
args.cache_size,
)
model = SAGELightning(
datamodule.in_feats,
args.num_hidden,
datamodule.n_classes,
args.num_layers,
args.model,
F.relu,
args.dropout,
args.lr,
datamodule.multilabel,
)
# Train
callbacks = []
if not args.disable_checkpoint:
# callbacks.append(ModelCheckpoint(monitor='val_acc/dataloader_idx_0', save_top_k=1, mode='max'))
callbacks.append(
ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")
)
callbacks.append(BatchSizeCallback(args.vertex_limit))
# callbacks.append(EarlyStopping(monitor='val_acc/dataloader_idx_0', stopping_threshold=args.val_acc_target, mode='max', patience=args.early_stopping_patience))
callbacks.append(
EarlyStopping(
monitor="val_acc",
stopping_threshold=args.val_acc_target,
mode="max",
patience=args.early_stopping_patience,
)
)
subdir = "{}_{}_{}_{}_{}_{}".format(
args.dataset,
args.sampler,
args.importance_sampling,
args.layer_dependency,
args.batch_dependency,
args.independent_batches,
)
logger = TensorBoardLogger(args.logdir, name=subdir)
trainer = Trainer(
accelerator="gpu" if args.gpu != -1 else "cpu",
devices=[args.gpu],
accumulate_grad_batches=args.independent_batches,
max_epochs=args.num_epochs,
max_steps=args.num_steps,
min_steps=args.min_steps,
callbacks=callbacks,
logger=logger,
)
trainer.fit(model, datamodule=datamodule)
# Test
if not args.disable_checkpoint:
logdir = os.path.join(args.logdir, subdir)
dirs = glob.glob("./{}/*".format(logdir))
version = max([int(os.path.split(x)[-1].split("_")[-1]) for x in dirs])
logdir = "./{}/version_{}".format(logdir, version)
print("Evaluating model in", logdir)
ckpt = glob.glob(os.path.join(logdir, "checkpoints", "*"))[0]
model = SAGELightning.load_from_checkpoint(
checkpoint_path=ckpt,
hparams_file=os.path.join(logdir, "hparams.yaml"),
).to(device)
test_acc = trainer.test(model, datamodule=datamodule)
print("Test accuracy:", test_acc)