"vscode:/vscode.git/clone" did not exist on "f032e56f8af4d40bdbb74d15fa9c5cd63f8b7a40"
Unverified Commit f8594230 authored by MB Wang, committed by GitHub

[GraphBolt] GraphBolt Node Classification Example (#6359)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent 0a87bc6a
"""
This script trains and tests a GraphSAGE model for node classification
on large graphs using GraphBolt dataloader.
Paper: [Inductive Representation Learning on Large Graphs]
(https://arxiv.org/abs/1706.02216)
Unlike previous dgl examples, we've utilized the newly defined dataloader
from GraphBolt. This example will help you grasp how to build an end-to-end
training pipeline using GraphBolt.
Before reading this example, please familiarize yourself with graphsage node
classification by reading the example in
`examples/core/graphsage/node_classification.py`. This introduction,
[A Blitz Introduction to Node Classification with DGL]
(https://docs.dgl.ai/tutorials/blitz/1_introduction.html), might be helpful.
If you want to train graphsage on a large graph in a distributed fashion,
please read the example in `examples/distributed/graphsage/`.
This flowchart describes the main functional sequence of the provided example:
main
├───> OnDiskDataset pre-processing
├───> Instantiate SAGE model
├───> train
│     │
│     ├───> Get graphbolt dataloader (HIGHLIGHT)
│     │
│     └───> Training loop
│           │
│           ├───> SAGE.forward
│           │
│           └───> Validation set evaluation
└───> Test set evaluation
"""
import argparse
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm


class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.set_layer_dtype(torch.float64)

    def set_layer_dtype(self, dtype):
        for layer in self.layers:
            for param in layer.parameters():
                param.data = param.data.to(dtype)

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x


def create_dataloader(args, graph, features, itemset, is_train=True):
    """
    [HIGHLIGHT]
    Get a GraphBolt version of a dataloader for node classification tasks.
    This function demonstrates how to utilize functional forms of datapipes in
    GraphBolt.
    Alternatively, you can create a datapipe using its class constructor.
    """
    ############################################################################
    # [Step-1]:
    # gb.ItemSampler()
    # [Input]:
    # 'itemset': The current dataset. (e.g. `train_set` or `valid_set`)
    # 'args.batch_size': Specifies the number of samples to be processed
    # together, referred to as a 'mini-batch'. (The term 'mini-batch' is used
    # here to indicate a subset of the entire dataset that is processed
    # together. This is in contrast to processing the entire dataset, known as
    # a 'full batch'.)
    # 'is_train': Determines whether the data should be shuffled. (Shuffling is
    # generally used only in training to improve model generalization. It's
    # not used in validation and testing as the focus there is to evaluate
    # performance rather than to learn from the data.)
    # [Output]:
    # An ItemSampler object for handling mini-batch sampling.
    # [Role]:
    # Initialize the ItemSampler to sample mini-batches from the dataset.
    ############################################################################
    datapipe = gb.ItemSampler(
        itemset, batch_size=args.batch_size, shuffle=is_train
    )
    ############################################################################
    # [Step-2]:
    # datapipe.sample_neighbor()
    # [Input]:
    # 'datapipe': The ItemSampler datapipe constructed in Step-1.
    # 'graph': The network topology for sampling.
    # 'args.fanout': Number of neighbors to sample per node.
    # [Output]:
    # A NeighborSampler object to sample neighbors.
    # [Role]:
    # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
    ############################################################################
    datapipe = datapipe.sample_neighbor(graph, args.fanout)
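    # As the docstring above notes, the same stage could instead be created
    # with the datapipe's class constructor. A minimal sketch, assuming
    # gb.NeighborSampler accepts (datapipe, graph, fanouts) in this release:
    #   datapipe = gb.NeighborSampler(datapipe, graph, args.fanout)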
    ############################################################################
    # [Step-3]:
    # datapipe.fetch_feature()
    # [Input]:
    # 'features': The node features.
    # 'node_feature_keys': The keys of the node features to be fetched.
    # [Output]:
    # A FeatureFetcher object to fetch node features.
    # [Role]:
    # Initialize a feature fetcher for fetching features of the sampled
    # subgraphs.
    ############################################################################
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
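    # A class-constructor form is possible here too. A minimal sketch, assuming
    # gb.FeatureFetcher takes (datapipe, feature_store, node_feature_keys):
    #   datapipe = gb.FeatureFetcher(datapipe, features, ["feat"])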
    ############################################################################
    # [Step-4]:
    # gb.MultiProcessDataLoader()
    # [Input]:
    # 'datapipe': The datapipe object to be used for data loading.
    # 'args.num_workers': The number of processes to be used for data loading.
    # [Output]:
    # A MultiProcessDataLoader object to handle data loading.
    # [Role]:
    # Initialize a multi-process dataloader to load the data in parallel.
    ############################################################################
    dataloader = gb.MultiProcessDataLoader(
        datapipe, num_workers=args.num_workers
    )
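    # Each iteration over this dataloader yields a mini-batch object; the rest
    # of this example consumes it as follows (see evaluate() and train()):
    #   for data in dataloader:
    #       blocks = data.to_dgl_blocks()
    #       x = data.node_features["feat"]
    #       y = data.labels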
    # Return the fully-initialized DataLoader object.
    return dataloader


@torch.no_grad()
def evaluate(args, model, graph, features, itemset, num_classes):
    model.eval()
    y = []
    y_hats = []
    dataloader = create_dataloader(
        args, graph, features, itemset, is_train=False
    )

    for step, data in tqdm.tqdm(enumerate(dataloader)):
        blocks = data.to_dgl_blocks()
        x = data.node_features["feat"]
        y.append(data.labels)
        y_hats.append(model(blocks, x))

    res = MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )
    return res


def train(args, graph, features, train_set, valid_set, num_classes, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    dataloader = create_dataloader(
        args, graph, features, train_set, is_train=True
    )

    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        for step, data in tqdm.tqdm(enumerate(dataloader)):
            # The input features from the source nodes in the first layer's
            # computation graph.
            x = data.node_features["feat"]

            # The ground truth labels from the destination nodes
            # in the last layer's computation graph.
            y = data.labels

            # TODO[Mingbang]: Move the to_dgl_blocks() to a datapipe stage later
            # The predicted labels.
            y_hat = model(data.to_dgl_blocks(), x)

            # Compute loss.
            loss = F.cross_entropy(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluate the model.
        print("Validating...")
        acc = evaluate(args, model, graph, features, valid_set, num_classes)
        print(
            f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
            f"Accuracy {acc.item():.4f} "
        )


def parse_args():
    parser = argparse.ArgumentParser(
        description="A script that trains and tests a GraphSAGE model "
        "for node classification using GraphBolt dataloader."
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0005,
        help="Learning rate for optimization.",
    )
    parser.add_argument(
        "--batch-size", type=int, default=256, help="Batch size for training."
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of workers for data loading.",
    )
    parser.add_argument(
        "--fanout",
        type=str,
        default="15,10,5",
        help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
        " identical to the number of layers in your model. Default: 15,10,5",
    )
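    # For example, the default "15,10,5" matches the three SAGEConv layers in
    # the SAGE model above; a two-layer model would need a two-element fanout
    # such as "15,10".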
    parser.add_argument(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
    )
    return parser.parse_args()


def main(args):
    if not torch.cuda.is_available():
        args.device = "cpu"
    print(f"Training in {args.device} mode.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-products").load()
    graph = dataset.graph
    features = dataset.feature
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    args.fanout = list(map(int, args.fanout.split(",")))
    num_classes = dataset.tasks[0].metadata["num_classes"]

    # TODO[Mingbang]: Replace this with a more elegant API.
    in_size = features.read("node", None, "feat").shape[
        1
    ]  # Size of feature of a single node.
    hidden_size = 128
    out_size = num_classes
    model = SAGE(in_size, hidden_size, out_size)

    # Model training.
    print("Training...")
    train(args, graph, features, train_set, valid_set, num_classes, model)

    # Test the model.
    print("Testing...")
    test_set = dataset.tasks[0].test_set
    test_acc = evaluate(
        args, model, graph, features, itemset=test_set, num_classes=num_classes
    )
    print(f"Test Accuracy is {test_acc.item():.4f}")


if __name__ == "__main__":
    args = parse_args()
    main(args)