Unverified Commit 75804a7d authored by Mingbang Wang, committed by GitHub

[GraphBolt] Enable the model to conduct layer-wise inference (#6482)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent 03406bd9
@@ -35,10 +35,12 @@ main
 │           │
 │           └───> Validation set evaluation
-└───> Test set evaluation
+└───> All nodes set inference & Test set evaluation
 """
 import argparse
+from typing import Literal
+
 import dgl.graphbolt as gb
 import dgl.nn as dglnn
 import torch
@@ -48,43 +50,29 @@ import torchmetrics.functional as MF
 import tqdm


-class SAGE(nn.Module):
-    def __init__(self, in_size, hidden_size, out_size):
-        super().__init__()
-        self.layers = nn.ModuleList()
-        # Three-layer GraphSAGE-mean.
-        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
-        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
-        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
-        self.dropout = nn.Dropout(0.5)
-        self.hidden_size = hidden_size
-        self.out_size = out_size
-        # Set the dtype for the layers manually.
-        self.set_layer_dtype(torch.float64)
-
-    def set_layer_dtype(self, dtype):
-        for layer in self.layers:
-            for param in layer.parameters():
-                param.data = param.data.to(dtype)
-
-    def forward(self, blocks, x):
-        hidden_x = x
-        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
-            hidden_x = layer(block, hidden_x)
-            is_last_layer = layer_idx == len(self.layers) - 1
-            if not is_last_layer:
-                hidden_x = F.relu(hidden_x)
-                hidden_x = self.dropout(hidden_x)
-        return hidden_x
-
-
-def create_dataloader(args, graph, features, itemset, is_train=True):
+def create_dataloader(
+    args, graph, features, itemset, job: Literal["train", "evaluate", "infer"]
+):
     """
     [HIGHLIGHT]
     Get a GraphBolt version of a dataloader for node classification tasks.
     This function demonstrates how to utilize functional forms of datapipes in
     GraphBolt.
     Alternatively, you can create a datapipe using its class constructor.
+
+    Parameters
+    ----------
+    args : Namespace
+        The arguments parsed by `parser.parse_args()`.
+    graph : SamplingGraph
+        The network topology for sampling.
+    features : FeatureStore
+        The node features.
+    itemset : Union[ItemSet, ItemSetDict]
+        Data to be sampled.
+    job : Literal["train", "evaluate", "infer"]
+        The stage where dataloader is created, with options "train", "evaluate"
+        and "infer".
     """

     ############################################################################
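For orientation, a minimal sketch (not part of this commit) of how the new job argument is exercised; it mirrors the call sites that appear later in this diff and assumes args, graph, features and the item sets have been created as in main():

# The same helper now builds all three dataloaders; only 'job' changes.
train_dataloader = create_dataloader(args, graph, features, train_set, job="train")
valid_dataloader = create_dataloader(args, graph, features, valid_set, job="evaluate")
infer_dataloader = create_dataloader(args, graph, features, all_nodes_set, job="infer")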
@@ -96,7 +84,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # referred to as a 'mini-batch'. (The term 'mini-batch' is used here to
     # indicate a subset of the entire dataset that is processed together. This
     # is in contrast to processing the entire dataset, known as a 'full batch'.)
-    # 'is_train': Determining if data should be shuffled. (Shuffling is
+    # 'job': Determines whether data should be shuffled. (Shuffling is
     # generally used only in training to improve model generalization. It's
     # not used in validation and testing as the focus there is to evaluate
     # performance rather than to learn from the data.)
@@ -106,23 +94,26 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # Initialize the ItemSampler to sample mini-batches from the dataset.
     ############################################################################
     datapipe = gb.ItemSampler(
-        itemset, batch_size=args.batch_size, shuffle=is_train
+        itemset, batch_size=args.batch_size, shuffle=(job == "train")
     )

     ############################################################################
     # [Step-2]:
     # self.sample_neighbor()
     # [Input]:
-    # 'datapipe' is either 'ItemSampler' or 'UniformNegativeSampler' depending
-    # on whether training is needed ('is_train'),
     # 'graph': The network topology for sampling.
-    # 'args.fanout': Number of neighbors to sample per node.
+    # '[-1] or args.fanout': Number of neighbors to sample per node. In
+    # training or validation, the length of args.fanout should be equal to the
+    # number of layers in the model. In inference, this parameter is set to
+    # [-1], indicating that all neighbors of a node are sampled.
     # [Output]:
     # A NeighborSampler object to sample neighbors.
     # [Role]:
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
-    datapipe = datapipe.sample_neighbor(graph, args.fanout)
+    datapipe = datapipe.sample_neighbor(
+        graph, args.fanout if job != "infer" else [-1]
+    )

     ############################################################################
     # [Step-3]:
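A small illustration (not part of this commit) of the two fan-out settings the sampler can receive, using the example's default --fanout of 15,10,5:

# Training / validation: one fan-out entry per SAGEConv layer.
datapipe = datapipe.sample_neighbor(graph, [15, 10, 5])
# Inference: a single hop that keeps every neighbor, so each mini-batch
# carries exactly one block, which SAGE.inference (added below) relies on.
datapipe = datapipe.sample_neighbor(graph, [-1])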
@@ -134,9 +125,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # A FeatureFetcher object to fetch node features.
     # [Role]:
     # Initialize a feature fetcher for fetching features of the sampled
-    # subgraphs.
+    # subgraphs. This step is skipped during inference because each layer is
+    # computed over all nodes at once, so per-mini-batch features are unneeded.
     ############################################################################
-    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
+    if job != "infer":
+        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

     ############################################################################
     # [Step-4]:
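The reason the FeatureFetcher can be skipped for job="infer" is visible in the SAGE.inference method added further down; a short excerpt-style sketch:

# Layer-wise inference reads the full feature tensor once ...
feature = features.read("node", None, "feat")
# ... and slices it per mini-batch by the sampled input nodes, so attaching
# features to every mini-batch via fetch_feature would be redundant.
x = feature[data.input_nodes]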
@@ -169,13 +162,88 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     return dataloader


+class SAGE(nn.Module):
+    def __init__(self, in_size, hidden_size, out_size):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        # Three-layer GraphSAGE-mean.
+        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
+        self.dropout = nn.Dropout(0.5)
+        self.hidden_size = hidden_size
+        self.out_size = out_size
+        # Set the dtype for the layers manually.
+        self.set_layer_dtype(torch.float64)
+
+    def set_layer_dtype(self, _dtype):
+        for layer in self.layers:
+            for param in layer.parameters():
+                param.data = param.data.to(_dtype)
+
+    def forward(self, blocks, x):
+        hidden_x = x
+        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
+            hidden_x = layer(block, hidden_x)
+            is_last_layer = layer_idx == len(self.layers) - 1
+            if not is_last_layer:
+                hidden_x = F.relu(hidden_x)
+                hidden_x = self.dropout(hidden_x)
+        return hidden_x
+
+    def inference(self, graph, features, dataloader):
+        """Conduct layer-wise inference to get all the node embeddings."""
+        feature = features.read("node", None, "feat")
+
+        for layer_idx, layer in enumerate(self.layers):
+            is_last_layer = layer_idx == len(self.layers) - 1
+
+            y = torch.empty(
+                graph.total_num_nodes,
+                self.out_size if is_last_layer else self.hidden_size,
+                dtype=torch.float64,
+            )
+            for step, data in tqdm.tqdm(enumerate(dataloader)):
+                x = feature[data.input_nodes]
+                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
+                if not is_last_layer:
+                    hidden_x = F.relu(hidden_x)
+                    hidden_x = self.dropout(hidden_x)
+                # By design, our output nodes are contiguous.
+                y[data.output_nodes[0] : data.output_nodes[-1] + 1] = hidden_x
+            feature = y
+
+        return y
+
+
+@torch.no_grad()
+def layerwise_infer(
+    args, graph, features, test_set, all_nodes_set, model, num_classes
+):
+    model.eval()
+    dataloader = create_dataloader(
+        args, graph, features, all_nodes_set, job="infer"
+    )
+    pred = model.inference(graph, features, dataloader)
+    pred = pred[test_set._items[0]]
+    label = test_set._items[1].to(pred.device)
+
+    return MF.accuracy(
+        pred,
+        label,
+        task="multiclass",
+        num_classes=num_classes,
+    )
+
+
 @torch.no_grad()
 def evaluate(args, model, graph, features, itemset, num_classes):
     model.eval()
     y = []
     y_hats = []
     dataloader = create_dataloader(
-        args, graph, features, itemset, is_train=False
+        args, graph, features, itemset, job="evaluate"
     )

     for step, data in tqdm.tqdm(enumerate(dataloader)):
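A brief usage sketch (not part of this commit; the names assume the objects created in main() below) showing how the new layer-wise path is driven end to end:

model.eval()
infer_dataloader = create_dataloader(
    args, graph, features, all_nodes_set, job="infer"
)
with torch.no_grad():
    # One pass over all nodes per layer; the result has one row per node.
    embeddings = model.inference(graph, features, infer_dataloader)
assert embeddings.shape == (graph.total_num_nodes, model.out_size)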
@@ -183,20 +251,18 @@ def evaluate(args, model, graph, features, itemset, num_classes):
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))

-    res = MF.accuracy(
+    return MF.accuracy(
         torch.cat(y_hats),
         torch.cat(y),
         task="multiclass",
         num_classes=num_classes,
     )
-    return res


 def train(args, graph, features, train_set, valid_set, num_classes, model):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     dataloader = create_dataloader(
-        args, graph, features, train_set, is_train=True
+        args, graph, features, train_set, job="train"
     )

     for epoch in tqdm.trange(args.epochs):
@@ -223,7 +289,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
             total_loss += loss.item()

         # Evaluate the model.
-        print("Validating...")
         acc = evaluate(args, model, graph, features, valid_set, num_classes)
         print(
             f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
@@ -261,20 +326,10 @@ def parse_args():
         help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
         " identical with the number of layers in your model. Default: 15,10,5",
     )
-    parser.add_argument(
-        "--device",
-        default="cpu",
-        choices=["cpu", "cuda"],
-        help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
-    )
     return parser.parse_args()


 def main(args):
-    if not torch.cuda.is_available():
-        args.device = "cpu"
-    print(f"Training in {args.device} mode.")
-
     # Load and preprocess dataset.
     dataset = gb.BuiltinDataset("ogbn-products").load()
@@ -282,6 +337,8 @@ def main(args):
     features = dataset.feature
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
+    test_set = dataset.tasks[0].test_set
+    all_nodes_set = dataset.all_nodes_set
     args.fanout = list(map(int, args.fanout.split(",")))
     num_classes = dataset.tasks[0].metadata["num_classes"]
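A tiny illustration (not part of this commit) of the fan-out parsing above, using the parser's default value:

fanout = list(map(int, "15,10,5".split(",")))
print(fanout)  # [15, 10, 5] -- one entry per SAGEConv layer in the model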
@@ -298,9 +355,14 @@ def main(args):
     # Test the model.
     print("Testing...")
-    test_set = dataset.tasks[0].test_set
-    test_acc = evaluate(
-        args, model, graph, features, itemset=test_set, num_classes=num_classes
+    test_acc = layerwise_infer(
+        args,
+        graph,
+        features,
+        test_set,
+        all_nodes_set,
+        model,
+        num_classes,
     )
     print(f"Test Accuracy is {test_acc.item():.4f}")
...