Unverified Commit c0ac2f60 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt] Improving the node prediction example (#6353)

parent 6a9142c2
@@ -9,5 +9,5 @@ python3 node_classification.py
 ### Results
 ```
-Valid Accuracy: 0.877
+Valid Accuracy: 0.907
 ```
\ No newline at end of file
@@ -10,7 +10,8 @@ main
 │     │
 │     └───> ItemSampler (Distribute data to minibatches)
 │           │
-│           └───> sample_neighbor (Sample a subgraph for a minibatch)
+│           └───> sample_neighbor or sample_layer_neighbor
+│                 (Sample a subgraph for a minibatch)
 │                 │
 │                 └───> fetch_feature (Fetch features for the sampled subgraph)
...@@ -29,7 +30,7 @@ main ...@@ -29,7 +30,7 @@ main
└───> Trainer[HIGHLIGHT] └───> Trainer[HIGHLIGHT]
├───> SAGE.forward (RGCN model forward pass) ├───> SAGE.forward (GraphSAGE model forward pass)
└───> Validate └───> Validate
""" """
@@ -42,7 +43,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torchmetrics import Accuracy
@@ -69,6 +70,28 @@ class SAGE(LightningModule):
         h = self.dropout(h)
         return h
 
+    def log_node_and_edge_counts(self, blocks):
+        node_counts = [block.num_src_nodes() for block in blocks] + [
+            blocks[-1].num_dst_nodes()
+        ]
+        edge_counts = [block.num_edges() for block in blocks]
+        for i, c in enumerate(node_counts):
+            self.log(
+                f"num_nodes/{i}",
+                float(c),
+                prog_bar=True,
+                on_step=True,
+                on_epoch=False,
+            )
+            if i < len(edge_counts):
+                self.log(
+                    f"num_edges/{i}",
+                    float(edge_counts[i]),
+                    prog_bar=True,
+                    on_step=True,
+                    on_epoch=False,
+                )
+
     def training_step(self, batch, batch_idx):
         # TODO: Move this to the data pipeline as a stage.
         blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
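Review note: `log_node_and_edge_counts` logs one node count per layer boundary (the source side of every block, plus the destination side of the last block) and one edge count per layer, which makes the sampler's work visible on the progress bar. A toy check of the indexing, using hypothetical stand-in blocks with made-up sizes:

```python
# Hypothetical stand-ins for DGL blocks, only to illustrate the indexing.
class FakeBlock:
    def __init__(self, n_src, n_dst, n_edges):
        self._src, self._dst, self._edges = n_src, n_dst, n_edges

    def num_src_nodes(self):
        return self._src

    def num_dst_nodes(self):
        return self._dst

    def num_edges(self):
        return self._edges

# Outermost layer first, shrinking toward the 100 seed nodes.
blocks = [
    FakeBlock(4000, 1200, 9000),
    FakeBlock(1200, 300, 2500),
    FakeBlock(300, 100, 600),
]
node_counts = [b.num_src_nodes() for b in blocks] + [blocks[-1].num_dst_nodes()]
edge_counts = [b.num_edges() for b in blocks]
print(node_counts)  # [4000, 1200, 300, 100] -> num_nodes/0 .. num_nodes/3
print(edge_counts)  # [9000, 2500, 600]      -> num_edges/0 .. num_edges/2
```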
@@ -84,6 +107,7 @@ class SAGE(LightningModule):
             on_step=True,
             on_epoch=False,
         )
+        self.log_node_and_edge_counts(blocks)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -96,10 +120,11 @@ class SAGE(LightningModule):
             "val_acc",
             self.val_acc,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
             sync_dist=True,
         )
+        self.log_node_and_edge_counts(blocks)
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(
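Review note: switching `val_acc` to `on_step=False, on_epoch=True` is the right call here: validation accuracy is only meaningful over the whole epoch, and torchmetrics pools the correct/total counts rather than averaging per-batch scores. A quick illustration of the difference (not from the diff; assumes a recent torchmetrics where `Accuracy` takes a `task` argument):

```python
import torch
from torchmetrics import Accuracy

acc = Accuracy(task="multiclass", num_classes=3)
acc.update(torch.tensor([0, 1]), torch.tensor([0, 2]))              # batch 1: 1/2 correct
acc.update(torch.tensor([1, 2, 0, 1]), torch.tensor([1, 2, 0, 0]))  # batch 2: 3/4 correct
# Pooled over the epoch: 4/6 = 0.6667, not the 0.625 mean of batch scores.
print(acc.compute())
```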
@@ -114,75 +139,67 @@ class DataModule(LightningDataModule):
         self.fanouts = fanouts
         self.batch_size = batch_size
         self.num_workers = num_workers
-        # TODO: Update with a publicly accessible URL once the dataset has been
-        # uploaded.
-        dataset = gb.OnDiskDataset(
-            "/home/ubuntu/workspace/example_ogbn_products/"
-        )
-        dataset.load()
+        dataset = gb.BuiltinDataset("ogbn-products").load()
         self.feature_store = dataset.feature
         self.graph = dataset.graph
         self.train_set = dataset.tasks[0].train_set
         self.valid_set = dataset.tasks[0].validation_set
         self.num_classes = dataset.tasks[0].metadata["num_classes"]
 
-    ########################################################################
-    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
-    # essential components of the Lightning framework, defining how data is
-    # loaded during training and validation. In this example, we utilize a
-    # specialized 'graphbolt dataloader', which are concatenated by a series
-    # of datappipes, for these purposes.
-    ########################################################################
-    def train_dataloader(self):
+    def create_dataloader(self, node_set, is_train):
         datapipe = gb.ItemSampler(
-            self.train_set,
+            node_set,
             batch_size=self.batch_size,
             shuffle=True,
             drop_last=True,
         )
-        datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
+        sampler = (
+            datapipe.sample_layer_neighbor
+            if is_train
+            else datapipe.sample_neighbor
+        )
+        datapipe = sampler(self.graph, self.fanouts)
         datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
         dataloader = gb.MultiProcessDataLoader(
             datapipe, num_workers=self.num_workers
         )
         return dataloader
 
+    ########################################################################
+    # (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
+    # essential components of the Lightning framework, defining how data is
+    # loaded during training and validation. In this example, we use a
+    # specialized GraphBolt dataloader, built by chaining a series of
+    # datapipes, for these purposes.
+    ########################################################################
+    def train_dataloader(self):
+        return self.create_dataloader(self.train_set, is_train=True)
+
     def val_dataloader(self):
-        datapipe = gb.ItemSampler(
-            self.valid_set,
-            batch_size=self.batch_size,
-            shuffle=True,
-            drop_last=True,
-        )
-        datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
-        datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
-        dataloader = gb.MultiProcessDataLoader(
-            datapipe, num_workers=self.num_workers
-        )
-        return dataloader
+        return self.create_dataloader(self.valid_set, is_train=False)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="GNN baselines on ogbgmol* data with Pytorch Geometrics"
+        description="GNN baselines on ogbn-products data with GraphBolt"
     )
     parser.add_argument(
         "--num_gpus",
         type=int,
-        default=4,
-        help="number of GPUs used for computing (default: 4)",
+        default=1,
+        help="number of GPUs used for computing (default: 1)",
     )
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=32,
-        help="input batch size for training (default: 32)",
+        default=1024,
+        help="input batch size for training (default: 1024)",
     )
     parser.add_argument(
         "--epochs",
         type=int,
-        default=10,
-        help="number of epochs to train (default: 100)",
+        default=40,
+        help="number of epochs to train (default: 40)",
     )
     parser.add_argument(
         "--num_workers",
@@ -196,7 +213,8 @@ if __name__ == "__main__":
     model = SAGE(100, 256, datamodule.num_classes).to(torch.double)
 
     # Train.
-    checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
+    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
+    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
     ########################################################################
     # (HIGHLIGHT) The `Trainer` is the key class in Lightning, which automates
     # everything after defining `LightningDataModule` and
@@ -207,6 +225,6 @@ if __name__ == "__main__":
         accelerator="gpu",
         devices=args.num_gpus,
         max_epochs=args.epochs,
-        callbacks=[checkpoint_callback],
+        callbacks=[checkpoint_callback, early_stopping_callback],
     )
     trainer.fit(model, datamodule=datamodule)
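Review note: with Lightning's defaults, the new `EarlyStopping(monitor="val_acc", mode="max")` stops training after 3 validation epochs without improvement, and `ModelCheckpoint` keeps `save_top_k=1` by default, so dropping the explicit `save_top_k` changes nothing. An equivalent, more explicit configuration with the defaults spelled out:

```python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    mode="max",    # keep the checkpoint with the highest validation accuracy
    save_top_k=1,  # Lightning's default, kept explicit here
)
early_stopping_callback = EarlyStopping(
    monitor="val_acc",
    mode="max",
    patience=3,    # default: stop after 3 non-improving validation epochs
    min_delta=0.0, # default: any improvement resets the counter
)
```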