"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f6262182069ecbea5fb05cbad207864fa7708427"
Unverified Commit adf49937 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add lightning example (#6335)

parent cbc34705
# Node classification on a homogeneous graph with GraphSAGE
## Run on `ogbn-products` dataset
### Command
```
python3 node_classification.py
```
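The example also exposes a few optional flags via `argparse` (see the script below), including `--num_gpus`, `--batch_size`, `--epochs` and `--num_workers`; for example:
```
python3 node_classification.py --num_gpus 4 --batch_size 32 --epochs 10
```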
### Results
```
Valid Accuracy: 0.877
```
"""
This flowchart describes the main functional sequence of the provided example.
main
├───> Instantiate DataModule
│ │
│ └───> Load dataset
│ │
│ └───> Create train and valid dataloaders[HIGHLIGHT]
│ │
│ └───> ItemSampler (Distribute data to minibatches)
│ │
│ └───> sample_neighbor (Sample a subgraph for a minibatch)
│ │
│ └───> fetch_feature (Fetch features for the sampled subgraph)
├───> Instantiate GraphSAGE model
│ │
│ ├───> SAGEConvLayer (input to hidden)
│ │
│ ├───> SAGEConvLayer (hidden to hidden)
│ │
│ ├───> SAGEConvLayer (hidden to output)
│ │
│ └───> DropoutLayer
└───> Run
└───> Trainer[HIGHLIGHT]
├───> SAGE.forward (GraphSAGE model forward pass)
└───> Validate
"""
import argparse
import dgl.graphbolt as gb
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
class SAGE(LightningModule):
def __init__(self, in_feats, n_hidden, n_classes):
super().__init__()
self.save_hyperparameters()
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden
self.n_classes = n_classes
self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
def training_step(self, batch, batch_idx):
# TODO: Move this to the data pipeline as a stage.
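        # `to_dgl_blocks` converts the sampled minibatch into a list of DGL
        # blocks (message flow graphs), ordered from the input layer to the
        # output layer, which is why blocks[0] carries the input features.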
blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
x = blocks[0].srcdata["feat"]
y = batch.labels.to("cuda")
y_hat = self(blocks, x)
loss = F.cross_entropy(y_hat, y)
self.train_acc(torch.argmax(y_hat, 1), y)
self.log(
"train_acc",
self.train_acc,
prog_bar=True,
on_step=True,
on_epoch=False,
)
return loss
def validation_step(self, batch, batch_idx):
blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
x = blocks[0].srcdata["feat"]
y = batch.labels.to("cuda")
y_hat = self(blocks, x)
self.val_acc(torch.argmax(y_hat, 1), y)
self.log(
"val_acc",
self.val_acc,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
)
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(), lr=0.001, weight_decay=5e-4
)
return optimizer
class DataModule(LightningDataModule):
def __init__(self, fanouts, batch_size, num_workers):
super().__init__()
self.fanouts = fanouts
self.batch_size = batch_size
self.num_workers = num_workers
# TODO: Update with a publicly accessible URL once the dataset has been
# uploaded.
dataset = gb.OnDiskDataset(
"/home/ubuntu/workspace/example_ogbn_products/"
)
dataset.load()
self.feature_store = dataset.feature
self.graph = dataset.graph
self.train_set = dataset.tasks[0].train_set
self.valid_set = dataset.tasks[0].validation_set
self.num_classes = dataset.tasks[0].metadata["num_classes"]
########################################################################
# (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
# essential components of the Lightning framework, defining how data is
    # loaded during training and validation. In this example, we use a
    # specialized GraphBolt dataloader, which is built by chaining a series
    # of datapipes, for these purposes.
########################################################################
def train_dataloader(self):
datapipe = gb.ItemSampler(
self.train_set,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
)
datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=self.num_workers
)
return dataloader
def val_dataloader(self):
datapipe = gb.ItemSampler(
self.valid_set,
batch_size=self.batch_size,
            # Keep validation deterministic and complete: no shuffling and no
            # dropped batches, so metrics cover the full validation set.
            shuffle=False,
            drop_last=False,
)
datapipe = datapipe.sample_neighbor(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=self.num_workers
)
return dataloader
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="GraphSAGE node classification on ogbn-products with "
        "GraphBolt and PyTorch Lightning"
    )
parser.add_argument(
"--num_gpus",
type=int,
default=4,
help="number of GPUs used for computing (default: 4)",
)
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="input batch size for training (default: 32)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
help="number of epochs to train (default: 100)",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers (default: 0)",
)
args = parser.parse_args()
datamodule = DataModule([15, 10, 5], args.batch_size, args.num_workers)
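    # The three fanouts correspond to the three SAGEConv layers: 15, 10 and 5
    # neighbors are sampled per node at the first, second and third hop.
    # `in_feats=100` below matches the 100-dimensional node features of
    # ogbn-products.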
model = SAGE(100, 256, datamodule.num_classes).to(torch.double)
# Train.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
########################################################################
    # (HIGHLIGHT) The `Trainer` is the key class in Lightning; it automates
    # everything once the `LightningDataModule` and `LightningModule` are
    # defined. More details can be found at
# https://lightning.ai/docs/pytorch/stable/common/trainer.html.
########################################################################
trainer = Trainer(
accelerator="gpu",
devices=args.num_gpus,
max_epochs=args.epochs,
callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule=datamodule)
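    # After training, the checkpoint with the best validation accuracy can be
    # restored via the standard Lightning API, e.g.:
    #     best_model = SAGE.load_from_checkpoint(
    #         checkpoint_callback.best_model_path
    #     )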