Unverified commit 75804a7d authored by Mingbang Wang, committed by GitHub

[GraphBolt] Enable the model to conduct layer-wise inference (#6482)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent 03406bd9
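This commit adds an offline, layer-wise inference path: rather than sampling a full L-hop neighborhood for every test node (whose size grows with the product of the per-layer fanouts), each layer is run exactly once over all nodes, and its output embeddings are cached as the input of the next layer. Below is a minimal sketch of that pattern, with plain nn.Linear layers standing in for the graph convolutions and all GraphBolt specifics omitted; every name in it is illustrative, not taken from the commit.

```python
import torch

# Illustrative sizes, not from the commit.
num_nodes, in_size, hidden_size, out_size = 1000, 16, 32, 8
layers = [
    torch.nn.Linear(in_size, hidden_size),
    torch.nn.Linear(hidden_size, out_size),
]

x = torch.randn(num_nodes, in_size)  # stands in for the stored node features
with torch.no_grad():
    for layer in layers:
        y = torch.empty(num_nodes, layer.out_features)
        # Run ALL nodes through one layer, mini-batch by mini-batch ...
        for start in range(0, num_nodes, 256):
            end = min(start + 256, num_nodes)
            y[start:end] = layer(x[start:end])
        # ... then the cached outputs become the next layer's inputs.
        x = y
print(x.shape)  # torch.Size([1000, 8])
```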
@@ -35,10 +35,12 @@ main
│     │
│     └───> Validation set evaluation
└───> Test set evaluation
└───> All nodes set inference & Test set evaluation
"""
import argparse
from typing import Literal
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
@@ -48,43 +50,29 @@ import torchmetrics.functional as MF
import tqdm
class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.set_layer_dtype(torch.float64)

    def set_layer_dtype(self, dtype):
        for layer in self.layers:
            for param in layer.parameters():
                param.data = param.data.to(dtype)

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x


def create_dataloader(args, graph, features, itemset, is_train=True):
def create_dataloader(
    args, graph, features, itemset, job: Literal["train", "evaluate", "infer"]
):
"""
[HIGHLIGHT]
Get a GraphBolt version of a dataloader for node classification tasks.
This function demonstrates how to utilize functional forms of datapipes in
GraphBolt.
Alternatively, you can create a datapipe using its class constructor.
Parameters
----------
args : Namespace
The arguments parsed by `parser.parse_args()`.
graph : SamplingGraph
The network topology for sampling.
features : FeatureStore
The node features.
itemset : Union[ItemSet, ItemSetDict]
Data to be sampled.
job : Literal["train", "evaluate", "infer"]
The stage where dataloader is created, with options "train", "evaluate"
and "infer".
"""
############################################################################
@@ -96,7 +84,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
    # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
    # indicate a subset of the entire dataset that is processed together. This
    # is in contrast to processing the entire dataset, known as a 'full batch'.)
    # 'is_train': Determining if data should be shuffled. (Shuffling is
    # 'job': Determines whether data should be shuffled. (Shuffling is
    # generally used only in training to improve model generalization. It's
    # not used in validation and testing as the focus there is to evaluate
    # performance rather than to learn from the data.)
@@ -106,23 +94,26 @@
    # Initialize the ItemSampler to sample mini-batches from the dataset.
    ############################################################################
    datapipe = gb.ItemSampler(
        itemset, batch_size=args.batch_size, shuffle=is_train
        itemset, batch_size=args.batch_size, shuffle=(job == "train")
    )
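    # Each item yielded downstream is a mini-batch of up to 'args.batch_size'
    # seed items; the final batch may be smaller.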
    ############################################################################
    # [Step-2]:
    # datapipe.sample_neighbor()
    # [Input]:
    # 'datapipe': The ItemSampler constructed in Step-1, which yields
    # mini-batches of seed nodes,
    # 'graph': The network topology for sampling.
    # 'args.fanout': Number of neighbors to sample per node.
    # '[-1] or args.fanout': Number of neighbors to sample per node. In
    # training or validation, the length of args.fanout should be equal to the
    # number of layers in the model. In inference, this parameter is set to
    # [-1], indicating that all neighbors of a node are sampled.
    # [Output]:
    # A NeighborSampler object to sample neighbors.
    # [Role]:
    # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
    ############################################################################
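    # Note: the functional form below is interchangeable with the class
    # constructor, along the lines of gb.NeighborSampler(datapipe, graph,
    # fanouts) -- the exact constructor signature is assumed here.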
    datapipe = datapipe.sample_neighbor(graph, args.fanout)
    datapipe = datapipe.sample_neighbor(
        graph, args.fanout if job != "infer" else [-1]
    )
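    # With the default fanout of [15, 10, 5] (see parse_args below), up to 15,
    # 10 and 5 neighbors are sampled for the first, second and third layer
    # respectively.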
    ############################################################################
    # [Step-3]:
@@ -134,8 +125,10 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
    # A FeatureFetcher object to fetch node features.
    # [Role]:
    # Initialize a feature fetcher for fetching features of the sampled
    # subgraphs.
    # subgraphs. This step is skipped for inference because SAGE.inference
    # reads and updates the features of all nodes as a whole, so fetching
    # features per mini-batch is unnecessary.
    ############################################################################
    if job != "infer":
        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    ############################################################################
@@ -169,13 +162,88 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
    return dataloader
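For orientation, here are the three `job` modes of this function as they are invoked elsewhere in this file, collected side by side; the snippet is illustrative and not part of the commit, with variable names taken from `main()` below.

```python
train_loader = create_dataloader(
    args, graph, features, train_set, job="train"
)  # shuffled; len(args.fanout) must equal the number of layers
val_loader = create_dataloader(
    args, graph, features, valid_set, job="evaluate"
)  # unshuffled; same fanout handling as training
infer_loader = create_dataloader(
    args, graph, features, all_nodes_set, job="infer"
)  # unshuffled; fanout forced to [-1]; feature fetching skipped
```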
class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.set_layer_dtype(torch.float64)
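        # Note: float64 matches the dtype of the output buffer allocated in
        # inference() below; presumably the stored node features are float64
        # as well (assumption, not stated in this diff).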

    def set_layer_dtype(self, _dtype):
        for layer in self.layers:
            for param in layer.parameters():
                param.data = param.data.to(_dtype)

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x

    def inference(self, graph, features, dataloader):
        """Conduct layer-wise inference to get all the node embeddings."""
        feature = features.read("node", None, "feat")

        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx == len(self.layers) - 1

            y = torch.empty(
                graph.total_num_nodes,
                self.out_size if is_last_layer else self.hidden_size,
                dtype=torch.float64,
            )
            for step, data in tqdm.tqdm(enumerate(dataloader)):
                x = feature[data.input_nodes]
                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
                # By design, our output nodes are contiguous.
                y[data.output_nodes[0] : data.output_nodes[-1] + 1] = hidden_x
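            # The embeddings just computed for this layer serve as the input
            # features of the next layer.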
            feature = y

        return y


@torch.no_grad()
def layerwise_infer(
    args, graph, features, test_set, all_nodes_set, model, num_classes
):
    model.eval()
    dataloader = create_dataloader(
        args, graph, features, all_nodes_set, job="infer"
    )
    pred = model.inference(graph, features, dataloader)
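    # 'pred' now holds an embedding for every node in the graph; keep only the
    # predictions and labels of the test nodes. ('_items' is the ItemSet's
    # internal (IDs, labels) storage.)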
    pred = pred[test_set._items[0]]
    label = test_set._items[1].to(pred.device)

    return MF.accuracy(
        pred,
        label,
        task="multiclass",
        num_classes=num_classes,
    )


@torch.no_grad()
def evaluate(args, model, graph, features, itemset, num_classes):
    model.eval()
    y = []
    y_hats = []
    dataloader = create_dataloader(
        args, graph, features, itemset, is_train=False
        args, graph, features, itemset, job="evaluate"
    )

    for step, data in tqdm.tqdm(enumerate(dataloader)):
@@ -183,20 +251,18 @@ def evaluate(args, model, graph, features, itemset, num_classes):
        y.append(data.labels)
        y_hats.append(model(data.blocks, x))

    res = MF.accuracy(
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )
    return res


def train(args, graph, features, train_set, valid_set, num_classes, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    dataloader = create_dataloader(
        args, graph, features, train_set, is_train=True
        args, graph, features, train_set, job="train"
    )

    for epoch in tqdm.trange(args.epochs):
@@ -223,7 +289,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
            total_loss += loss.item()

        # Evaluate the model.
        print("Validating...")
        acc = evaluate(args, model, graph, features, valid_set, num_classes)
        print(
            f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
@@ -261,20 +326,10 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
return parser.parse_args()
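# Example invocation (the script name is an assumption; flags other than
# --fanout are truncated in this diff):
#   python node_classification.py --fanout 15,10,5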


def main(args):
    if not torch.cuda.is_available():
        args.device = "cpu"
    print(f"Training in {args.device} mode.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-products").load()
@@ -282,6 +337,8 @@ def main(args):
    features = dataset.feature
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    test_set = dataset.tasks[0].test_set
    all_nodes_set = dataset.all_nodes_set
    args.fanout = list(map(int, args.fanout.split(",")))

    num_classes = dataset.tasks[0].metadata["num_classes"]
@@ -298,9 +355,14 @@ def main(args):
    # Test the model.
    print("Testing...")
    test_set = dataset.tasks[0].test_set
    test_acc = evaluate(
        args, model, graph, features, itemset=test_set, num_classes=num_classes
    test_acc = layerwise_infer(
        args,
        graph,
        features,
        test_set,
        all_nodes_set,
        model,
        num_classes,
    )
    print(f"Test Accuracy is {test_acc.item():.4f}")