Unverified Commit c0ac2f60 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Improving the node prediction example (#6353)

parent 6a9142c2
......@@ -9,5 +9,5 @@ python3 node_classification.py
### Results
```
Valid Accuracy: 0.877
Valid Accuracy: 0.907
```
\ No newline at end of file
......@@ -10,7 +10,8 @@ main
│ │
│ └───> ItemSampler (Distribute data to minibatchs)
│ │
│ └───> sample_neighbor (Sample a subgraph for a minibatch)
│ └───> sample_neighbor or sample_layer_neighbor
(Sample a subgraph for a minibatch)
│ │
│ └───> fetch_feature (Fetch features for the sampled subgraph)
......@@ -29,7 +30,7 @@ main
└───> Trainer[HIGHLIGHT]
├───> SAGE.forward (RGCN model forward pass)
├───> SAGE.forward (GraphSAGE model forward pass)
└───> Validate
"""
......@@ -42,7 +43,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import Accuracy
......@@ -69,6 +70,28 @@ class SAGE(LightningModule):
h = self.dropout(h)
return h
def log_node_and_edge_counts(self, blocks):
node_counts = [block.num_src_nodes() for block in blocks] + [
blocks[-1].num_dst_nodes()
]
edge_counts = [block.num_edges() for block in blocks]
for i, c in enumerate(node_counts):
self.log(
f"num_nodes/{i}",
float(c),
prog_bar=True,
on_step=True,
on_epoch=False,
)
if i < len(edge_counts):
self.log(
f"num_edges/{i}",
float(edge_counts[i]),
prog_bar=True,
on_step=True,
on_epoch=False,
)
def training_step(self, batch, batch_idx):
# TODO: Move this to the data pipeline as a stage.
blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
......@@ -84,6 +107,7 @@ class SAGE(LightningModule):
on_step=True,
on_epoch=False,
)
self.log_node_and_edge_counts(blocks)
return loss
def validation_step(self, batch, batch_idx):
......@@ -96,10 +120,11 @@ class SAGE(LightningModule):
"val_acc",
self.val_acc,
prog_bar=True,
on_step=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log_node_and_edge_counts(blocks)
def configure_optimizers(self):
optimizer = torch.optim.Adam(
......@@ -114,75 +139,67 @@ class DataModule(LightningDataModule):
self.fanouts = fanouts
self.batch_size = batch_size
self.num_workers = num_workers
# TODO: Update with a publicly accessible URL once the dataset has been
# uploaded.
dataset = gb.OnDiskDataset(
"/home/ubuntu/workspace/example_ogbn_products/"
)
dataset.load()
dataset = gb.BuiltinDataset("ogbn-products").load()
self.feature_store = dataset.feature
self.graph = dataset.graph
self.train_set = dataset.tasks[0].train_set
self.valid_set = dataset.tasks[0].validation_set
self.num_classes = dataset.tasks[0].metadata["num_classes"]
########################################################################
# (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
# essential components of the Lightning framework, defining how data is
# loaded during training and validation. In this example, we utilize a
# specialized 'graphbolt dataloader', which are concatenated by a series
# of datappipes, for these purposes.
########################################################################
def train_dataloader(self):
def create_dataloader(self, node_set, is_train):
datapipe = gb.ItemSampler(
self.train_set,
node_set,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
)
datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
sampler = (
datapipe.sample_layer_neighbor
if is_train
else datapipe.sample_neighbor
)
datapipe = sampler(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=self.num_workers
)
return dataloader
########################################################################
# (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
# essential components of the Lightning framework, defining how data is
# loaded during training and validation. In this example, we utilize a
# specialized 'graphbolt dataloader', which are concatenated by a series
# of datapipes, for these purposes.
########################################################################
def train_dataloader(self):
return self.create_dataloader(self.train_set, is_train=True)
def val_dataloader(self):
datapipe = gb.ItemSampler(
self.valid_set,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
)
datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=self.num_workers
)
return dataloader
return self.create_dataloader(self.valid_set, is_train=False)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="GNN baselines on ogbgmol* data with Pytorch Geometrics"
description="GNN baselines on ogbn-products data with GraphBolt"
)
parser.add_argument(
"--num_gpus",
type=int,
default=4,
help="number of GPUs used for computing (default: 4)",
default=1,
help="number of GPUs used for computing (default: 1)",
)
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="input batch size for training (default: 32)",
default=1024,
help="input batch size for training (default: 1024)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
help="number of epochs to train (default: 100)",
default=40,
help="number of epochs to train (default: 40)",
)
parser.add_argument(
"--num_workers",
......@@ -196,7 +213,8 @@ if __name__ == "__main__":
model = SAGE(100, 256, datamodule.num_classes).to(torch.double)
# Train.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
########################################################################
# (HIGHLIGHT) The `Trainer` is the key Class in lightning, which automates
# everything after defining `LightningDataModule` and
......@@ -207,6 +225,6 @@ if __name__ == "__main__":
accelerator="gpu",
devices=args.num_gpus,
max_epochs=args.epochs,
callbacks=[checkpoint_callback],
callbacks=[checkpoint_callback, early_stopping_callback],
)
trainer.fit(model, datamodule=datamodule)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment