"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "325a5de3a9acc97534a4446ce9dd4147efcd61a0"
Unverified commit 7e4c226d authored by Muhammed Fatih BALIN, committed by GitHub

[Model] Simplify labor example, add proper inference code (#6104)

parent cf829dc3
Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
============
- Paper link: [https://arxiv.org/abs/2210.13339](https://arxiv.org/abs/2210.13339)
This is the official LABOR sampling example for reproducing the results in the original
paper with the GraphSAGE GNN model. The model can be swapped for any other model in
which NeighborSampler can be used.
Requirements
------------
```bash
pip install requests lightning==2.0.6 ogb
```
How to run
-------
### Minibatch training for node classification
Train with mini-batch sampling on the GPU for node classification on `ogbn-products`:
```bash
python3 train_lightning.py --dataset=ogbn-products
```
Results:
```
Test Accuracy: 0.797
```
Any integer passed as the `--importance-sampling=i` argument runs the corresponding
LABOR-i variant. `--importance-sampling=-1` runs the LABOR-* variant.
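For example, the following command runs the LABOR-* variant on `ogbn-products`:
```bash
python3 train_lightning.py --dataset=ogbn-products --importance-sampling=-1
```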
The `--vertex-limit` argument is used if a vertex sampling budget is needed. It adjusts
the batch size at the end of every epoch so that the average number of sampled vertices
converges to the provided vertex limit, which can be used to replicate the vertex sampling
budget experiments in the LABOR paper.
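As an illustrative sketch, a run with a hypothetical budget of 250000 average sampled
vertices (the exact value depends on the experiment being replicated) would be:
```bash
python3 train_lightning.py --dataset=ogbn-products --vertex-limit=250000
```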
During training, statistics about the number of sampled vertices and edges and the GPU
cache miss rate are reported. One can use TensorBoard to inspect these plots during or
after training:
```bash
tensorboard --logdir tb_logs
```
### Utilize a GPU feature cache for UVA training
```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000
```
### Reduce GPU feature cache miss rate for UVA training
```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000 --batch-dependency=64
```
### Force all layers to share the same neighborhood for shared vertices
```bash
python3 train_lightning.py --dataset=ogbn-products --layer-dependency
```
model.py:
@@ -5,84 +5,7 @@ import sklearn.metrics as skm
 import torch as th
 import torch.functional as F
 import torch.nn as nn
+import tqdm
-from dgl.nn import GATv2Conv
-class GATv2(nn.Module):
-    def __init__(
-        self,
-        num_layers,
-        in_dim,
-        num_hidden,
-        num_classes,
-        heads,
-        activation,
-        feat_drop,
-        attn_drop,
-        negative_slope,
-        residual,
-    ):
-        super(GATv2, self).__init__()
-        self.num_layers = num_layers
-        self.gatv2_layers = nn.ModuleList()
-        self.activation = activation
-        # input projection (no residual)
-        self.gatv2_layers.append(
-            GATv2Conv(
-                in_dim,
-                num_hidden,
-                heads[0],
-                feat_drop,
-                attn_drop,
-                negative_slope,
-                False,
-                self.activation,
-                True,
-                bias=False,
-                share_weights=True,
-            )
-        )
-        # hidden layers
-        for l in range(1, num_layers - 1):
-            # due to multi-head, the in_dim = num_hidden * num_heads
-            self.gatv2_layers.append(
-                GATv2Conv(
-                    num_hidden * heads[l - 1],
-                    num_hidden,
-                    heads[l],
-                    feat_drop,
-                    attn_drop,
-                    negative_slope,
-                    residual,
-                    self.activation,
-                    True,
-                    bias=False,
-                    share_weights=True,
-                )
-            )
-        # output projection
-        self.gatv2_layers.append(
-            GATv2Conv(
-                num_hidden * heads[-2],
-                num_classes,
-                heads[-1],
-                feat_drop,
-                attn_drop,
-                negative_slope,
-                residual,
-                None,
-                True,
-                bias=False,
-                share_weights=True,
-            )
-        )
-
-    def forward(self, mfgs, h):
-        for l, mfg in enumerate(mfgs):
-            h = self.gatv2_layers[l](mfg, h)
-            h = h.flatten(1) if l < self.num_layers - 1 else h.mean(1)
-        return h
 class SAGE(nn.Module):
@@ -124,87 +47,41 @@ class SAGE(nn.Module):
                 h = self.dropout(h)
         return h
-class RGAT(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        num_etypes,
-        num_layers,
-        num_heads,
-        dropout,
-        pred_ntype,
-    ):
-        super().__init__()
-        self.convs = nn.ModuleList()
-        self.norms = nn.ModuleList()
-        self.skips = nn.ModuleList()
-        self.convs.append(
-            nn.ModuleList(
-                [
-                    dglnn.GATConv(
-                        in_channels,
-                        hidden_channels // num_heads,
-                        num_heads,
-                        allow_zero_in_degree=True,
-                    )
-                    for _ in range(num_etypes)
-                ]
-            )
-        )
-        self.norms.append(nn.BatchNorm1d(hidden_channels))
-        self.skips.append(nn.Linear(in_channels, hidden_channels))
-        for _ in range(num_layers - 1):
-            self.convs.append(
-                nn.ModuleList(
-                    [
-                        dglnn.GATConv(
-                            hidden_channels,
-                            hidden_channels // num_heads,
-                            num_heads,
-                            allow_zero_in_degree=True,
-                        )
-                        for _ in range(num_etypes)
-                    ]
-                )
-            )
-            self.norms.append(nn.BatchNorm1d(hidden_channels))
-            self.skips.append(nn.Linear(hidden_channels, hidden_channels))
-        self.mlp = nn.Sequential(
-            nn.Linear(hidden_channels, hidden_channels),
-            nn.BatchNorm1d(hidden_channels),
-            nn.ReLU(),
-            nn.Dropout(dropout),
-            nn.Linear(hidden_channels, out_channels),
-        )
-        self.dropout = nn.Dropout(dropout)
-        self.hidden_channels = hidden_channels
-        self.pred_ntype = pred_ntype
-        self.num_etypes = num_etypes
-
-    def forward(self, mfgs, x):
-        for i in range(len(mfgs)):
-            mfg = mfgs[i]
-            x_dst = x[mfg.dst_in_src]
-            for data in [mfg.srcdata, mfg.dstdata]:
-                for k in list(data.keys()):
-                    if k not in ["features", "labels"]:
-                        data.pop(k)
-            mfg = dgl.block_to_graph(mfg)
-            x_skip = self.skips[i](x_dst)
-            for j in range(self.num_etypes):
-                subg = mfg.edge_subgraph(
-                    mfg.edata["etype"] == j, relabel_nodes=False
-                )
-                x_skip += self.convs[i][j](subg, (x, x_dst)).view(
-                    -1, self.hidden_channels
-                )
-            x = self.norms[i](x_skip)
-            x = th.nn.functional.elu(x)
-            x = self.dropout(x)
-        return self.mlp(x)
+    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
+        # The difference between this inference function and the one in the official
+        # example is that the intermediate results can also benefit from prefetching.
+        g.ndata["h"] = g.ndata["features"]
+        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
+            1, prefetch_node_feats=["h"]
+        )
+        dataloader = dgl.dataloading.DataLoader(
+            g,
+            th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=num_workers,
+            persistent_workers=(num_workers > 0),
+        )
+        if buffer_device is None:
+            buffer_device = device
+        self.train(False)
+        for l, layer in enumerate(self.layers):
+            y = th.zeros(
+                g.num_nodes(),
+                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
+                device=buffer_device,
+            )
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                x = blocks[0].srcdata["h"]
+                h = layer(blocks[0], x)
+                if l != len(self.layers) - 1:
+                    h = self.activation(h)
+                    h = self.dropout(h)
+                y[output_nodes] = h.to(buffer_device)
+            g.ndata["h"] = y
+        return y
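For context, the `inference` method added above performs layer-wise, full-neighborhood
inference over the whole graph. Below is a minimal sketch of how it could be exercised;
the toy graph, feature size, and hyperparameters are made up for illustration, and the
`SAGE` constructor signature is taken from the call in `train_lightning.py` further down
(assuming the example's `model.py` is importable from the working directory).
```python
import dgl
import torch as th
import torch.nn.functional as F

from model import SAGE  # the simplified model kept by this change

# Toy graph and random features, purely for illustration.
g = dgl.rand_graph(1000, 5000)
g.ndata["features"] = th.randn(g.num_nodes(), 16)

# SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout),
# matching the constructor call in train_lightning.py.
model = SAGE(16, 32, 8, 2, F.relu, 0.5)

with th.no_grad():
    # Layer-wise full-graph inference on CPU, 256 nodes per batch, 0 workers.
    pred = model.inference(g, "cpu", 256, 0)
print(pred.shape)  # (num_nodes, n_classes)
```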
train_lightning.py:
@@ -33,7 +33,7 @@ import torch.nn.functional as F
 from ladies_sampler import LadiesSampler, normalized_edata, PoissonLadiesSampler
 from load_graph import load_dataset
-from model import GATv2, RGAT, SAGE
+from model import SAGE
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -41,14 +41,6 @@ from pytorch_lightning.loggers import TensorBoardLogger
 from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score
-def cuda_index_tensor(tensor, idx):
-    assert idx.device != th.device("cpu")
-    if tensor.is_pinned():
-        return dgl.utils.gather_pinned_tensor_rows(tensor, idx)
-    else:
-        return tensor[idx.long()]
 class SAGELightning(LightningModule):
     def __init__(
         self,
@@ -56,7 +48,6 @@ class SAGELightning(LightningModule):
         n_hidden,
         n_classes,
         n_layers,
-        model,
         activation,
         dropout,
         lr,
@@ -64,54 +55,15 @@
     ):
         super().__init__()
         self.save_hyperparameters()
-        if model in ["sage"]:
-            self.module = (
-                SAGE(
-                    in_feats, n_hidden, n_classes, n_layers, activation, dropout
-                )
-                if in_feats != 768
-                else RGAT(
-                    in_feats,
-                    n_classes,
-                    n_hidden,
-                    5,
-                    n_layers,
-                    4,
-                    args.dropout,
-                    "paper",
-                )
-            )
-        else:
-            heads = ([8] * n_layers) + [1]
-            self.module = GATv2(
-                n_layers,
-                in_feats,
-                n_hidden,
-                n_classes,
-                heads,
-                activation,
-                dropout,
-                dropout,
-                0.2,
-                True,
-            )
+        self.module = SAGE(
+            in_feats, n_hidden, n_classes, n_layers, activation, dropout
+        )
         self.lr = lr
-        f1score_class = (
-            MulticlassF1Score if not multilabel else MultilabelF1Score
-        )
-        self.train_acc = f1score_class(n_classes, average="micro")
-        self.val_acc = nn.ModuleList(
-            [
-                f1score_class(n_classes, average="micro"),
-                f1score_class(n_classes, average="micro"),
-            ]
-        )
-        self.test_acc = nn.ModuleList(
-            [
-                f1score_class(n_classes, average="micro"),
-                f1score_class(n_classes, average="micro"),
-            ]
-        )
+        self.f1score_class = lambda: (
+            MulticlassF1Score if not multilabel else MultilabelF1Score
+        )(n_classes, average="micro")
+        self.train_acc = self.f1score_class()
+        self.val_acc = self.f1score_class()
         self.num_steps = 0
         self.cum_sampled_nodes = [0 for _ in range(n_layers + 1)]
         self.cum_sampled_edges = [0 for _ in range(n_layers)]
@@ -225,10 +177,10 @@ class SAGELightning(LightningModule):
         batch_labels = mfgs[-1].dstdata["labels"]
         batch_pred = self.module(mfgs, batch_inputs)
         loss = self.loss_fn(batch_pred, batch_labels)
-        self.val_acc[dataloader_idx](batch_pred, batch_labels.int())
+        self.val_acc(batch_pred, batch_labels.int())
         self.log(
             "val_acc",
-            self.val_acc[dataloader_idx],
+            self.val_acc,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
@@ -244,32 +196,6 @@
             batch_size=batch_labels.shape[0],
         )
-    def test_step(self, batch, batch_idx, dataloader_idx=0):
-        input_nodes, output_nodes, mfgs = batch
-        mfgs = [mfg.int().to(device) for mfg in mfgs]
-        batch_inputs = mfgs[0].srcdata["features"]
-        batch_labels = mfgs[-1].dstdata["labels"]
-        batch_pred = self.module(mfgs, batch_inputs)
-        loss = self.loss_fn(batch_pred, batch_labels)
-        self.test_acc[dataloader_idx](batch_pred, batch_labels.int())
-        self.log(
-            "test_acc",
-            self.test_acc[dataloader_idx],
-            prog_bar=True,
-            on_step=False,
-            on_epoch=True,
-            sync_dist=True,
-            batch_size=batch_labels.shape[0],
-        )
-        self.log(
-            "test_loss",
-            loss,
-            on_step=False,
-            on_epoch=True,
-            sync_dist=True,
-            batch_size=batch_labels.shape[0],
-        )
     def configure_optimizers(self):
         optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
         return optimizer
@@ -331,13 +257,6 @@ class DataModule(LightningDataModule):
             prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
             prefetch_labels=["labels"],
         )
-        full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
-            len(fanouts),
-            prefetch_node_feats=["features"],
-            prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
-            prefetch_labels=["labels"],
-        )
-        unbiased_sampler = sampler
         dataloader_device = th.device("cpu")
         g = g.formats(["csc"])
@@ -363,8 +282,6 @@
             test_nid,
         )
         self.sampler = sampler
-        self.unbiased_sampler = unbiased_sampler
-        self.full_sampler = full_sampler
         self.device = dataloader_device
         self.use_uva = use_uva
         self.batch_size = batch_size
@@ -389,38 +306,18 @@
         )
     def val_dataloader(self):
-        return [
-            dgl.dataloading.DataLoader(
-                self.g,
-                self.val_nid,
-                sampler,
-                device=self.device,
-                use_uva=self.use_uva,
-                batch_size=self.batch_size,
-                shuffle=False,
-                drop_last=False,
-                num_workers=self.num_workers,
-                gpu_cache=self.gpu_cache_arg,
-            )
-            for sampler in [self.unbiased_sampler]
-        ]
-    def test_dataloader(self):
-        return [
-            dgl.dataloading.DataLoader(
-                self.g,
-                self.test_nid,
-                sampler,
-                device=self.device,
-                use_uva=self.use_uva,
-                batch_size=self.batch_size,
-                shuffle=False,
-                drop_last=False,
-                num_workers=self.num_workers,
-                gpu_cache=self.gpu_cache_arg,
-            )
-            for sampler in [self.full_sampler]
-        ]
+        return dgl.dataloading.DataLoader(
+            self.g,
+            self.val_nid,
+            self.sampler,
+            device=self.device,
+            use_uva=self.use_uva,
+            batch_size=self.batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=self.num_workers,
+            gpu_cache=self.gpu_cache_arg,
+        )
 class BatchSizeCallback(Callback):
@@ -476,8 +373,10 @@ class BatchSizeCallback(Callback):
         trainer.datamodule.batch_size = int(
             trainer.datamodule.batch_size * self.limit / self.m
         )
-        trainer.reset_train_dataloader()
-        trainer.reset_val_dataloader()
+        loop = trainer._active_loop
+        assert loop is not None
+        loop._combined_loader = None
+        loop.setup_data()
         self.clear()
@@ -500,7 +399,6 @@ if __name__ == "__main__":
     argparser.add_argument("--batch-size", type=int, default=1024)
     argparser.add_argument("--lr", type=float, default=0.001)
     argparser.add_argument("--dropout", type=float, default=0.5)
-    argparser.add_argument("--independent-batches", type=int, default=1)
     argparser.add_argument(
         "--num-workers",
         type=int,
@@ -515,8 +413,12 @@ if __name__ == "__main__":
         "be undesired if they cannot fit in GPU memory at once. "
         "This flag disables that.",
     )
argparser.add_argument("--model", type=str, default="sage") argparser.add_argument(
argparser.add_argument("--sampler", type=str, default="labor") "--sampler",
type=str,
default="labor",
choices=["neighbor", "labor", "ladies", "poisson-ladies"],
)
argparser.add_argument("--importance-sampling", type=int, default=0) argparser.add_argument("--importance-sampling", type=int, default=0)
argparser.add_argument("--layer-dependency", action="store_true") argparser.add_argument("--layer-dependency", action="store_true")
argparser.add_argument("--batch-dependency", type=int, default=1) argparser.add_argument("--batch-dependency", type=int, default=1)
@@ -547,7 +449,7 @@ if __name__ == "__main__":
         [int(_) for _ in args.fan_out.split(",")],
         [int(_) for _ in args.lad_out.split(",")],
         device,
-        args.batch_size // args.independent_batches,
+        args.batch_size,
         args.num_workers,
         args.sampler,
         args.importance_sampling,
@@ -560,7 +462,6 @@ if __name__ == "__main__":
         args.num_hidden,
         datamodule.n_classes,
         args.num_layers,
-        args.model,
         F.relu,
         args.dropout,
         args.lr,
@@ -570,12 +471,10 @@ if __name__ == "__main__":
     # Train
     callbacks = []
     if not args.disable_checkpoint:
-        # callbacks.append(ModelCheckpoint(monitor='val_acc/dataloader_idx_0', save_top_k=1, mode='max'))
         callbacks.append(
             ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")
         )
     callbacks.append(BatchSizeCallback(args.vertex_limit))
-    # callbacks.append(EarlyStopping(monitor='val_acc/dataloader_idx_0', stopping_threshold=args.val_acc_target, mode='max', patience=args.early_stopping_patience))
     callbacks.append(
         EarlyStopping(
             monitor="val_acc",
@@ -584,19 +483,17 @@ if __name__ == "__main__":
             patience=args.early_stopping_patience,
         )
     )
subdir = "{}_{}_{}_{}_{}_{}".format( subdir = "{}_{}_{}_{}_{}".format(
args.dataset, args.dataset,
args.sampler, args.sampler,
args.importance_sampling, args.importance_sampling,
args.layer_dependency, args.layer_dependency,
args.batch_dependency, args.batch_dependency,
args.independent_batches,
) )
     logger = TensorBoardLogger(args.logdir, name=subdir)
     trainer = Trainer(
         accelerator="gpu" if args.gpu != -1 else "cpu",
         devices=[args.gpu],
-        accumulate_grad_batches=args.independent_batches,
         max_epochs=args.num_epochs,
         max_steps=args.num_steps,
         min_steps=args.min_steps,
@@ -618,5 +515,21 @@ if __name__ == "__main__":
         checkpoint_path=ckpt,
         hparams_file=os.path.join(logdir, "hparams.yaml"),
     ).to(device)
-    test_acc = trainer.test(model, datamodule=datamodule)
-    print("Test accuracy:", test_acc)
+    with th.no_grad():
+        graph = datamodule.g
+        pred = model.module.inference(
+            graph,
+            f"cuda:{args.gpu}" if args.gpu != -1 else "cpu",
+            4096,
+            args.num_workers,
+            graph.device,
+        )
+        for nid, split_name in zip(
+            [datamodule.train_nid, datamodule.val_nid, datamodule.test_nid],
+            ["Train", "Validation", "Test"],
+        ):
+            pred_nid = pred[nid]
+            label = graph.ndata["labels"][nid]
+            f1score = model.f1score_class().to(pred.device)
+            acc = f1score(pred_nid, label)
+            print(f"{split_name} accuracy:", acc.item())