Commit 2550ac5c authored by Ramon Zhou's avatar Ramon Zhou Committed by RhettYing
Browse files

[GraphBolt][PyG] Modify PyG example with `to_pyg_data` (#7123)


Co-authored-by: default avatarMuhammed Fatih BALIN <m.f.balin@gmail.com>
parent 9d7fe9d3
""" """
This script demonstrates node classification with GraphSAGE on large graphs, This script demonstrates node classification with GraphSAGE on large graphs,
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently manages merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently
data loading for large datasets, crucial for mini-batch processing. Post data manages data loading for large datasets, crucial for mini-batch processing.
loading, PyG's user-friendly framework takes over for training, showcasing seamless Post data loading, PyG's user-friendly framework takes over for training,
integration with GraphBolt. This combination offers an efficient alternative to showcasing seamless integration with GraphBolt. This combination offers an
traditional Deep Graph Library (DGL) methods, highlighting adaptability and efficient alternative to traditional Deep Graph Library (DGL) methods,
scalability in handling large-scale graph data for diverse real-world applications. highlighting adaptability and scalability in handling large-scale graph data
for diverse real-world applications.
Key Features: Key Features:
- Implements the GraphSAGE model, a scalable GNN, for node classification on large graphs. - Implements the GraphSAGE model, a scalable GNN, for node classification on
large graphs.
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing. - Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model. - Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- The script is well-documented, providing clear explanations at each step. - The script is well-documented, providing clear explanations at each step.
...@@ -38,6 +38,8 @@ main ...@@ -38,6 +38,8 @@ main
│ │ │ │
│ ├───> Forward and backward passes │ ├───> Forward and backward passes
│ │ │ │
│ ├───> Convert GraphBolt MiniBatch to PyG Data
│ │
│ └───> Parameters optimization │ └───> Parameters optimization
└───> Evaluate the model └───> Evaluate the model
...@@ -56,6 +58,7 @@ import torch ...@@ -56,6 +58,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF import torchmetrics.functional as MF
from torch_geometric.nn import SAGEConv from torch_geometric.nn import SAGEConv
from tqdm import tqdm
class GraphSAGE(torch.nn.Module): class GraphSAGE(torch.nn.Module):
...@@ -67,6 +70,8 @@ class GraphSAGE(torch.nn.Module): ...@@ -67,6 +70,8 @@ class GraphSAGE(torch.nn.Module):
# - 'in_size', 'hidden_size', 'out_size' are the sizes of # - 'in_size', 'hidden_size', 'out_size' are the sizes of
# the input, hidden, and output features, respectively. # the input, hidden, and output features, respectively.
# - The forward method defines the computation performed at every call. # - The forward method defines the computation performed at every call.
# - It's adopted from the official PyG example which can be found at
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
##################################################################### #####################################################################
def __init__(self, in_size, hidden_size, out_size): def __init__(self, in_size, hidden_size, out_size):
super(GraphSAGE, self).__init__() super(GraphSAGE, self).__init__()
...@@ -75,87 +80,83 @@ class GraphSAGE(torch.nn.Module): ...@@ -75,87 +80,83 @@ class GraphSAGE(torch.nn.Module):
self.layers.append(SAGEConv(hidden_size, hidden_size)) self.layers.append(SAGEConv(hidden_size, hidden_size))
self.layers.append(SAGEConv(hidden_size, out_size)) self.layers.append(SAGEConv(hidden_size, out_size))
def forward(self, blocks, x, device): def forward(self, x, edge_index):
h = x for i, layer in enumerate(self.layers):
for i, (layer, block) in enumerate(zip(self.layers, blocks)): x = layer(x, edge_index)
src, dst = block.edges() if i != len(self.layers) - 1:
edge_index = torch.stack([src, dst], dim=0) x = x.relu()
h_src, h_dst = h, h[: block.number_of_dst_nodes()] x = F.dropout(x, p=0.5, training=self.training)
h = layer((h_src, h_dst), edge_index) return x
if i != len(blocks) - 1:
h = F.relu(h)
return h
def inference(self, args, dataloader, x_all, device):
"""Conduct layer-wise inference to get all the node embeddings."""
for i, layer in tqdm(enumerate(self.layers), "inference"):
xs = []
for minibatch in dataloader:
# Call `to_pyg_data` to convert GB Minibatch to PyG Data.
pyg_data = minibatch.to_pyg_data()
n_ids = minibatch.node_ids().to("cpu")
x = x_all[n_ids].to(device)
edge_index = pyg_data.edge_index
x = layer(x, edge_index)
x = x[: 4 * args.batch_size]
if i != len(self.layers) - 1:
x = x.relu()
xs.append(x.cpu())
x_all = torch.cat(xs, dim=0)
return x_all
def create_dataloader(dataset_set, graph, feature, device, is_train):
#####################################################################
# (HIGHLIGHT) Create a data loader for efficiently loading graph data.
#
# - 'ItemSampler' samples mini-batches of node IDs from the dataset.
# - 'sample_neighbor' performs neighbor sampling on the graph.
# - 'FeatureFetcher' fetches node features based on the sampled subgraph.
# - 'CopyTo' copies the fetched data to the specified device.
#####################################################################
# Create a datapipe for mini-batch sampling with a specific neighbor fanout.
# Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.
# We're using `sample_neighbor` for consistency with DGL's sampling API.
# Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,
# which could provide further optimization and efficiency for GNN training.
# Users are encouraged to explore these advanced features for potentially improved performance.
def create_dataloader(
dataset_set, graph, feature, batch_size, fanout, device, job
):
# Initialize an ItemSampler to sample mini-batches from the dataset. # Initialize an ItemSampler to sample mini-batches from the dataset.
datapipe = gb.ItemSampler( datapipe = gb.ItemSampler(
dataset_set, batch_size=1024, shuffle=is_train, drop_last=is_train dataset_set,
batch_size=batch_size,
shuffle=(job == "train"),
drop_last=(job == "train"),
) )
# Sample neighbors for each node in the mini-batch. # Sample neighbors for each node in the mini-batch.
datapipe = datapipe.sample_neighbor(graph, [10, 10, 10]) datapipe = datapipe.sample_neighbor(
graph, fanout if job != "infer" else [-1]
)
# Copy the data to the specified device.
datapipe = datapipe.copy_to(device=device, extra_attrs=["input_nodes"])
# Fetch node features for the sampled subgraph. # Fetch node features for the sampled subgraph.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"]) datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
# Copy the data to the specified device.
datapipe = datapipe.copy_to(device=device)
# Create and return a DataLoader to handle data loading. # Create and return a DataLoader to handle data loading.
dataloader = gb.DataLoader(datapipe, num_workers=0) dataloader = gb.DataLoader(datapipe, num_workers=0)
return dataloader return dataloader
def train(model, dataloader, optimizer, criterion, device, num_classes): def train(model, dataloader, optimizer):
#####################################################################
# (HIGHLIGHT) Train the model for one epoch.
#
# - Iterates over the data loader, fetching mini-batches of graph data.
# - For each mini-batch, it performs a forward pass, computes loss, and
# updates the model parameters.
# - The function returns the average loss and accuracy for the epoch.
#
# Parameters:
# model: The GraphSAGE model.
# dataloader: DataLoader that provides mini-batches of graph data.
# optimizer: Optimizer used for updating model parameters.
# criterion: Loss function used for training.
# device: The device (CPU/GPU) to run the training on.
#####################################################################
model.train() # Set the model to training mode model.train() # Set the model to training mode
total_loss = 0 # Accumulator for the total loss total_loss = 0 # Accumulator for the total loss
total_correct = 0 # Accumulator for the total number of correct predictions total_correct = 0 # Accumulator for the total number of correct predictions
total_samples = 0 # Accumulator for the total number of samples processed total_samples = 0 # Accumulator for the total number of samples processed
num_batches = 0 # Counter for the number of mini-batches processed num_batches = 0 # Counter for the number of mini-batches processed
for minibatch in dataloader: for _, minibatch in tqdm(enumerate(dataloader), "training"):
node_features = minibatch.node_features["feat"] #####################################################################
labels = minibatch.labels # (HIGHLIGHT) Convert GraphBolt MiniBatch to PyG Data class.
#
# Call `MiniBatch.to_pyg_data()` and it will return a PyG Data class
# with necessary data and information.
#####################################################################
pyg_data = minibatch.to_pyg_data()
optimizer.zero_grad() optimizer.zero_grad()
out = model(minibatch.blocks, node_features, device) out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
loss = criterion(out, labels) y = pyg_data.y
total_loss += loss.item() loss = F.cross_entropy(out, y)
total_correct += MF.accuracy(
out, labels, task="multiclass", num_classes=num_classes
) * labels.size(0)
total_samples += labels.size(0)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += float(loss)
total_correct += int(out.argmax(dim=-1).eq(y).sum())
total_samples += y.shape[0]
num_batches += 1 num_batches += 1
avg_loss = total_loss / num_batches avg_loss = total_loss / num_batches
avg_accuracy = total_correct / total_samples avg_accuracy = total_correct / total_samples
...@@ -163,16 +164,16 @@ def train(model, dataloader, optimizer, criterion, device, num_classes): ...@@ -163,16 +164,16 @@ def train(model, dataloader, optimizer, criterion, device, num_classes):
@torch.no_grad() @torch.no_grad()
def evaluate(model, dataloader, device, num_classes): def evaluate(model, dataloader, num_classes):
model.eval() model.eval()
y_hats = [] y_hats = []
ys = [] ys = []
for minibatch in dataloader: for _, minibatch in tqdm(enumerate(dataloader), "evaluating"):
node_features = minibatch.node_features["feat"] pyg_data = minibatch.to_pyg_data()
labels = minibatch.labels out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
out = model(minibatch.blocks, node_features, device) y = pyg_data.y
y_hats.append(out) y_hats.append(out)
ys.append(labels) ys.append(y)
return MF.accuracy( return MF.accuracy(
torch.cat(y_hats), torch.cat(y_hats),
...@@ -182,6 +183,24 @@ def evaluate(model, dataloader, device, num_classes): ...@@ -182,6 +183,24 @@ def evaluate(model, dataloader, device, num_classes):
) )
@torch.no_grad()
def layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
):
model.eval()
features = feature.read("node", None, "feat")
pred = model.inference(args, infer_dataloader, features, device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)
return MF.accuracy(
pred,
label,
task="multiclass",
num_classes=num_classes,
)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Which dataset are you going to use?" description="Which dataset are you going to use?"
...@@ -189,45 +208,71 @@ def main(): ...@@ -189,45 +208,71 @@ def main():
parser.add_argument( parser.add_argument(
"--dataset", "--dataset",
type=str, type=str,
default="ogbn-arxiv", default="ogbn-products",
help='Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")', help='Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")',
) )
parser.add_argument(
"--epochs", type=int, default=10, help="Number of training epochs."
)
parser.add_argument(
"--batch-size", type=int, default=1024, help="Batch size for training."
)
args = parser.parse_args() args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = args.dataset dataset_name = args.dataset
dataset = gb.BuiltinDataset(dataset_name).load() dataset = gb.BuiltinDataset(dataset_name).load()
graph = dataset.graph graph = dataset.graph
feature = dataset.feature feature = dataset.feature.pin_memory_()
train_set = dataset.tasks[0].train_set train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set test_set = dataset.tasks[0].test_set
all_nodes_set = dataset.all_nodes_set
num_classes = dataset.tasks[0].metadata["num_classes"] num_classes = dataset.tasks[0].metadata["num_classes"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataloader = create_dataloader( train_dataloader = create_dataloader(
train_set, graph, feature, device, is_train=True train_set,
graph,
feature,
args.batch_size,
[5, 10, 15],
device,
job="train",
) )
valid_dataloader = create_dataloader( valid_dataloader = create_dataloader(
valid_set, graph, feature, device, is_train=False valid_set,
graph,
feature,
args.batch_size,
[5, 10, 15],
device,
job="evaluate",
) )
test_dataloader = create_dataloader( infer_dataloader = create_dataloader(
test_set, graph, feature, device, is_train=False all_nodes_set,
graph,
feature,
4 * args.batch_size,
[-1],
device,
job="infer",
) )
in_channels = feature.size("node", None, "feat")[0] in_channels = feature.size("node", None, "feat")[0]
hidden_channels = 128 hidden_channels = 256
model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device) model = GraphSAGE(in_channels, hidden_channels, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
criterion = torch.nn.CrossEntropyLoss() for epoch in range(args.epochs):
for epoch in range(10): train_loss, train_accuracy = train(model, train_dataloader, optimizer)
train_loss, train_accuracy = train(
model, train_dataloader, optimizer, criterion, device, num_classes
)
valid_accuracy = evaluate(model, valid_dataloader, device, num_classes) valid_accuracy = evaluate(model, valid_dataloader, num_classes)
print( print(
f"Epoch {epoch}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, " f"Epoch {epoch}, Train Loss: {train_loss:.4f}, "
f"Train Accuracy: {train_accuracy:.4f}, "
f"Valid Accuracy: {valid_accuracy:.4f}" f"Valid Accuracy: {valid_accuracy:.4f}"
) )
test_accuracy = evaluate(model, test_dataloader, device, num_classes) test_accuracy = layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
)
print(f"Test Accuracy: {test_accuracy:.4f}") print(f"Test Accuracy: {test_accuracy:.4f}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment