Unverified commit 7e4c226d authored by Muhammed Fatih BALIN, committed by GitHub

[Model] Simplify labor example, add proper inference code (#6104)

parent cf829dc3
Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
============
- Paper link: [https://arxiv.org/abs/2210.13339](https://arxiv.org/abs/2210.13339)
This is the official LABOR sampling example for reproducing the results in the paper above
with the GraphSAGE GNN model. The model can be swapped for any other model that works with
NeighborSampler, since the LABOR sampler is a drop-in replacement for it (see the sketch below).
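As a reference for that swap, here is a minimal, hypothetical sketch (not part of this example's code; the graph `g`, the training node IDs `train_nid`, the fanouts, and the batch size are placeholder assumptions) showing `dgl.dataloading.LaborSampler` being used in place of `dgl.dataloading.NeighborSampler` in a DGL `DataLoader`:
```python
import dgl


def make_train_dataloader(g, train_nid, use_labor=True):
    # The same fanouts work for both samplers; only the sampler class changes.
    fanouts = [10, 10, 10]
    if use_labor:
        # LABOR-0 by default; importance_sampling=i selects LABOR-i, -1 selects LABOR-*.
        sampler = dgl.dataloading.LaborSampler(fanouts)
    else:
        sampler = dgl.dataloading.NeighborSampler(fanouts)
    return dgl.dataloading.DataLoader(
        g,
        train_nid,
        sampler,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
    )
```
Only the sampler object changes; the model and training loop stay identical. The command-line flags described below (`--importance-sampling`, `--layer-dependency`, `--batch-dependency`) correspond to the `LaborSampler` arguments of the same names.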
Requirements
------------
```bash
pip install requests lightning==2.0.6 ogb
```
How to run
-------
### Minibatch training for node classification
Train with mini-batch sampling on the GPU for node classification on `ogbn-products`:
```bash
python3 train_lightning.py --dataset=ogbn-products
```
Results:
```
Test Accuracy: 0.797
```
Passing an integer i as the `--importance-sampling=i` argument runs the corresponding
LABOR-i variant, while `--importance-sampling=-1` runs the LABOR-* variant.
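For example, the following command runs the LABOR-* variant (the dataset choice is just for illustration):
```bash
python3 train_lightning.py --dataset=ogbn-products --importance-sampling=-1
```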
The `--vertex-limit` argument is used when a vertex sampling budget is needed. It adjusts
the batch size at the end of every epoch so that the average number of sampled vertices
converges to the provided limit, and it can be used to replicate the vertex sampling
budget experiments in the LABOR paper.
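For instance, the following command (the budget of 250000 vertices is an arbitrary, illustrative value) trains while steering the average number of sampled vertices toward that limit:
```bash
python3 train_lightning.py --dataset=ogbn-products --vertex-limit=250000
```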
During training, statistics on the number of sampled vertices and edges and on cache miss
rates are reported. TensorBoard can be used to plot them during or after training:
```bash
tensorboard --logdir tb_logs
```
### Utilize a GPU feature cache for UVA training
```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000
```
### Reduce GPU feature cache miss rate for UVA training
```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000 --batch-dependency=64
```
### Force all layers to share the same neighborhood for shared vertices
```bash
python3 train_lightning.py --dataset=ogbn-products --layer-dependency
```
@@ -5,84 +5,7 @@ import sklearn.metrics as skm
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from dgl.nn import GATv2Conv
class GATv2(nn.Module):
def __init__(
self,
num_layers,
in_dim,
num_hidden,
num_classes,
heads,
activation,
feat_drop,
attn_drop,
negative_slope,
residual,
):
super(GATv2, self).__init__()
self.num_layers = num_layers
self.gatv2_layers = nn.ModuleList()
self.activation = activation
# input projection (no residual)
self.gatv2_layers.append(
GATv2Conv(
in_dim,
num_hidden,
heads[0],
feat_drop,
attn_drop,
negative_slope,
False,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# hidden layers
for l in range(1, num_layers - 1):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[l - 1],
num_hidden,
heads[l],
feat_drop,
attn_drop,
negative_slope,
residual,
self.activation,
True,
bias=False,
share_weights=True,
)
)
# output projection
self.gatv2_layers.append(
GATv2Conv(
num_hidden * heads[-2],
num_classes,
heads[-1],
feat_drop,
attn_drop,
negative_slope,
residual,
None,
True,
bias=False,
share_weights=True,
)
)
def forward(self, mfgs, h):
for l, mfg in enumerate(mfgs):
h = self.gatv2_layers[l](mfg, h)
h = h.flatten(1) if l < self.num_layers - 1 else h.mean(1)
return h
import tqdm
class SAGE(nn.Module):
@@ -124,87 +47,41 @@ class SAGE(nn.Module):
h = self.dropout(h)
return h
# Removed by this commit: the RGAT model, which is no longer used by this example.
class RGAT(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        num_etypes,
        num_layers,
        num_heads,
        dropout,
        pred_ntype,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()
        self.convs.append(
            nn.ModuleList(
                [
                    dglnn.GATConv(
                        in_channels,
                        hidden_channels // num_heads,
                        num_heads,
                        allow_zero_in_degree=True,
                    )
                    for _ in range(num_etypes)
                ]
            )
        )
        self.norms.append(nn.BatchNorm1d(hidden_channels))
        self.skips.append(nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(
                nn.ModuleList(
                    [
                        dglnn.GATConv(
                            hidden_channels,
                            hidden_channels // num_heads,
                            num_heads,
                            allow_zero_in_degree=True,
                        )
                        for _ in range(num_etypes)
                    ]
                )
            )
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(hidden_channels, hidden_channels))
        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )
        self.dropout = nn.Dropout(dropout)
        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
        for i in range(len(mfgs)):
            mfg = mfgs[i]
            x_dst = x[mfg.dst_in_src]
            for data in [mfg.srcdata, mfg.dstdata]:
                for k in list(data.keys()):
                    if k not in ["features", "labels"]:
                        data.pop(k)
            mfg = dgl.block_to_graph(mfg)
            x_skip = self.skips[i](x_dst)
            for j in range(self.num_etypes):
                subg = mfg.edge_subgraph(
                    mfg.edata["etype"] == j, relabel_nodes=False
                )
                x_skip += self.convs[i][j](subg, (x, x_dst)).view(
                    -1, self.hidden_channels
                )
            x = self.norms[i](x_skip)
            x = th.nn.functional.elu(x)
            x = self.dropout(x)
        return self.mlp(x)

# Added by this commit: layer-by-layer offline inference as a method of the SAGE
# class above, so full-graph predictions can be computed after training.
    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        g.ndata["h"] = g.ndata["features"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["h"]
        )
        dataloader = dgl.dataloading.DataLoader(
            g,
            th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
        )
        if buffer_device is None:
            buffer_device = device
        self.train(False)
        for l, layer in enumerate(self.layers):
            y = th.zeros(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device,
            )
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata["h"] = y
        return y
@@ -33,7 +33,7 @@ import torch.nn.functional as F
from ladies_sampler import LadiesSampler, normalized_edata, PoissonLadiesSampler
from load_graph import load_dataset
from model import GATv2, RGAT, SAGE
from model import SAGE
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
@@ -41,14 +41,6 @@ from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score
def cuda_index_tensor(tensor, idx):
assert idx.device != th.device("cpu")
if tensor.is_pinned():
return dgl.utils.gather_pinned_tensor_rows(tensor, idx)
else:
return tensor[idx.long()]
class SAGELightning(LightningModule):
def __init__(
self,
@@ -56,7 +48,6 @@ class SAGELightning(LightningModule):
n_hidden,
n_classes,
n_layers,
model,
activation,
dropout,
lr,
@@ -64,54 +55,15 @@
):
super().__init__()
self.save_hyperparameters()
if model in ["sage"]:
self.module = (
SAGE(
in_feats, n_hidden, n_classes, n_layers, activation, dropout
)
if in_feats != 768
else RGAT(
in_feats,
n_classes,
n_hidden,
5,
n_layers,
4,
args.dropout,
"paper",
)
)
else:
heads = ([8] * n_layers) + [1]
self.module = GATv2(
n_layers,
in_feats,
n_hidden,
n_classes,
heads,
activation,
dropout,
dropout,
0.2,
True,
)
self.module = SAGE(
in_feats, n_hidden, n_classes, n_layers, activation, dropout
)
self.lr = lr
        # Removed by this commit: separate metric instances per val/test dataloader.
        f1score_class = (
            MulticlassF1Score if not multilabel else MultilabelF1Score
        )
        self.train_acc = f1score_class(n_classes, average="micro")
        self.val_acc = nn.ModuleList(
            [
                f1score_class(n_classes, average="micro"),
                f1score_class(n_classes, average="micro"),
            ]
        )
        self.test_acc = nn.ModuleList(
            [
                f1score_class(n_classes, average="micro"),
                f1score_class(n_classes, average="micro"),
            ]
        )
        # Added by this commit: a single metric factory shared by training, validation,
        # and the final inference step.
        self.f1score_class = lambda: (
            MulticlassF1Score if not multilabel else MultilabelF1Score
        )(n_classes, average="micro")
        self.train_acc = self.f1score_class()
        self.val_acc = self.f1score_class()
self.num_steps = 0
self.cum_sampled_nodes = [0 for _ in range(n_layers + 1)]
self.cum_sampled_edges = [0 for _ in range(n_layers)]
@@ -225,10 +177,10 @@ class SAGELightning(LightningModule):
batch_labels = mfgs[-1].dstdata["labels"]
batch_pred = self.module(mfgs, batch_inputs)
loss = self.loss_fn(batch_pred, batch_labels)
self.val_acc[dataloader_idx](batch_pred, batch_labels.int())
self.val_acc(batch_pred, batch_labels.int())
self.log(
"val_acc",
self.val_acc[dataloader_idx],
self.val_acc,
prog_bar=True,
on_step=False,
on_epoch=True,
@@ -244,32 +196,6 @@
batch_size=batch_labels.shape[0],
)
def test_step(self, batch, batch_idx, dataloader_idx=0):
input_nodes, output_nodes, mfgs = batch
mfgs = [mfg.int().to(device) for mfg in mfgs]
batch_inputs = mfgs[0].srcdata["features"]
batch_labels = mfgs[-1].dstdata["labels"]
batch_pred = self.module(mfgs, batch_inputs)
loss = self.loss_fn(batch_pred, batch_labels)
self.test_acc[dataloader_idx](batch_pred, batch_labels.int())
self.log(
"test_acc",
self.test_acc[dataloader_idx],
prog_bar=True,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
self.log(
"test_loss",
loss,
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_labels.shape[0],
)
def configure_optimizers(self):
optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
@@ -331,13 +257,6 @@ class DataModule(LightningDataModule):
prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
prefetch_labels=["labels"],
)
full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
len(fanouts),
prefetch_node_feats=["features"],
prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
prefetch_labels=["labels"],
)
unbiased_sampler = sampler
dataloader_device = th.device("cpu")
g = g.formats(["csc"])
@@ -363,8 +282,6 @@
test_nid,
)
self.sampler = sampler
self.unbiased_sampler = unbiased_sampler
self.full_sampler = full_sampler
self.device = dataloader_device
self.use_uva = use_uva
self.batch_size = batch_size
@@ -389,38 +306,18 @@
)
def val_dataloader(self):
return [
dgl.dataloading.DataLoader(
self.g,
self.val_nid,
sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
for sampler in [self.unbiased_sampler]
]
def test_dataloader(self):
return [
dgl.dataloading.DataLoader(
self.g,
self.test_nid,
sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
for sampler in [self.full_sampler]
]
return dgl.dataloading.DataLoader(
self.g,
self.val_nid,
self.sampler,
device=self.device,
use_uva=self.use_uva,
batch_size=self.batch_size,
shuffle=False,
drop_last=False,
num_workers=self.num_workers,
gpu_cache=self.gpu_cache_arg,
)
class BatchSizeCallback(Callback):
@@ -476,8 +373,10 @@ class BatchSizeCallback(Callback):
trainer.datamodule.batch_size = int(
trainer.datamodule.batch_size * self.limit / self.m
)
trainer.reset_train_dataloader()
trainer.reset_val_dataloader()
loop = trainer._active_loop
assert loop is not None
loop._combined_loader = None
loop.setup_data()
self.clear()
@@ -500,7 +399,6 @@ if __name__ == "__main__":
argparser.add_argument("--batch-size", type=int, default=1024)
argparser.add_argument("--lr", type=float, default=0.001)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument("--independent-batches", type=int, default=1)
argparser.add_argument(
"--num-workers",
type=int,
@@ -515,8 +413,12 @@ if __name__ == "__main__":
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.",
)
argparser.add_argument("--model", type=str, default="sage")
argparser.add_argument("--sampler", type=str, default="labor")
argparser.add_argument(
"--sampler",
type=str,
default="labor",
choices=["neighbor", "labor", "ladies", "poisson-ladies"],
)
argparser.add_argument("--importance-sampling", type=int, default=0)
argparser.add_argument("--layer-dependency", action="store_true")
argparser.add_argument("--batch-dependency", type=int, default=1)
@@ -547,7 +449,7 @@ if __name__ == "__main__":
[int(_) for _ in args.fan_out.split(",")],
[int(_) for _ in args.lad_out.split(",")],
device,
args.batch_size // args.independent_batches,
args.batch_size,
args.num_workers,
args.sampler,
args.importance_sampling,
@@ -560,7 +462,6 @@ if __name__ == "__main__":
args.num_hidden,
datamodule.n_classes,
args.num_layers,
args.model,
F.relu,
args.dropout,
args.lr,
@@ -570,12 +471,10 @@ if __name__ == "__main__":
# Train
callbacks = []
if not args.disable_checkpoint:
# callbacks.append(ModelCheckpoint(monitor='val_acc/dataloader_idx_0', save_top_k=1, mode='max'))
callbacks.append(
ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")
)
callbacks.append(BatchSizeCallback(args.vertex_limit))
# callbacks.append(EarlyStopping(monitor='val_acc/dataloader_idx_0', stopping_threshold=args.val_acc_target, mode='max', patience=args.early_stopping_patience))
callbacks.append(
EarlyStopping(
monitor="val_acc",
@@ -584,19 +483,17 @@ if __name__ == "__main__":
patience=args.early_stopping_patience,
)
)
subdir = "{}_{}_{}_{}_{}_{}".format(
subdir = "{}_{}_{}_{}_{}".format(
args.dataset,
args.sampler,
args.importance_sampling,
args.layer_dependency,
args.batch_dependency,
args.independent_batches,
)
logger = TensorBoardLogger(args.logdir, name=subdir)
trainer = Trainer(
accelerator="gpu" if args.gpu != -1 else "cpu",
devices=[args.gpu],
accumulate_grad_batches=args.independent_batches,
max_epochs=args.num_epochs,
max_steps=args.num_steps,
min_steps=args.min_steps,
@@ -618,5 +515,21 @@ if __name__ == "__main__":
checkpoint_path=ckpt,
hparams_file=os.path.join(logdir, "hparams.yaml"),
).to(device)
test_acc = trainer.test(model, datamodule=datamodule)
print("Test accuracy:", test_acc)
with th.no_grad():
graph = datamodule.g
pred = model.module.inference(
graph,
f"cuda:{args.gpu}" if args.gpu != -1 else "cpu",
4096,
args.num_workers,
graph.device,
)
for nid, split_name in zip(
[datamodule.train_nid, datamodule.val_nid, datamodule.test_nid],
["Train", "Validation", "Test"],
):
pred_nid = pred[nid]
label = graph.ndata["labels"][nid]
f1score = model.f1score_class().to(pred.device)
acc = f1score(pred_nid, label)
print(f"{split_name} accuracy:", acc.item())