Unverified Commit e82f1746 authored by LastWhisper's avatar LastWhisper Committed by GitHub

[Graphbolt] Enable Link Prediction example model.inference to accelerate the evaluation speed (#6557)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-12-115.ap-northeast-1.compute.internal>
parent 7d43c769
@@ -42,8 +42,8 @@ main
 └───> Test set evaluation
 """
 import argparse
+import time
 
-import dgl
 import dgl.graphbolt as gb
 import dgl.nn as dglnn
 import torch
@@ -78,6 +78,40 @@ class SAGE(nn.Module):
                 hidden_x = F.relu(hidden_x)
         return hidden_x
 
+    def inference(self, graph, features, dataloader, device):
+        """Conduct layer-wise inference to get all the node embeddings."""
+        feature = features.read("node", None, "feat")
+
+        buffer_device = torch.device("cpu")
+        # Enable pin_memory for faster CPU-to-GPU data transfer if the
+        # model is running on a GPU.
+        pin_memory = buffer_device != device
+
+        print("Start node embedding inference.")
+        for layer_idx, layer in enumerate(self.layers):
+            is_last_layer = layer_idx == len(self.layers) - 1
+
+            y = torch.empty(
+                graph.total_num_nodes,
+                self.hidden_size,
+                dtype=torch.float32,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
+            feature = feature.to(device)
+
+            for step, data in tqdm.tqdm(enumerate(dataloader)):
+                x = feature[data.input_nodes]
+                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
+                if not is_last_layer:
+                    hidden_x = F.relu(hidden_x)
+                # By design, the output nodes of each minibatch are contiguous.
+                y[
+                    data.output_nodes[0] : data.output_nodes[-1] + 1
+                ] = hidden_x.to(buffer_device, non_blocking=True)
+            feature = y
+        return y
+
 
 def create_dataloader(args, graph, features, itemset, is_train=True):
     """Get a GraphBolt version of a dataloader for link prediction tasks. This
@@ -166,8 +200,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # A FeatureFetcher object to fetch node features.
     # [Role]:
     # Initialize a feature fetcher for fetching features of the sampled
-    # subgraphs.
+    # subgraphs. This step is skipped in evaluation/inference because the
+    # features are read as a whole by layer-wise inference, so storing them
+    # per minibatch is unnecessary.
     ############################################################################
-    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
+    if is_train:
+        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
     ############################################################################
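
For context, the conditional above is one stage in the datapipe that create_dataloader assembles. A hedged sketch of the overall wiring follows; the stage names (gb.ItemSampler, sample_neighbor, gb.DataLoader) follow GraphBolt's datapipe API, but the exact argument lists in the example's create_dataloader may differ:

import dgl.graphbolt as gb

def build_datapipe(graph, features, itemset, batch_size, fanout, is_train):
    # Sample seed items (node pairs for link prediction).
    datapipe = gb.ItemSampler(itemset, batch_size=batch_size, shuffle=is_train)
    # Sample a fixed-fanout neighborhood around each minibatch.
    datapipe = datapipe.sample_neighbor(graph, fanout)
    if is_train:
        # Only training fetches features per minibatch; layer-wise
        # inference reads the whole feature tensor directly instead.
        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
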
@@ -223,60 +260,81 @@ def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
 @torch.no_grad()
-def evaluate(args, graph, features, itemset, model):
-    evaluator = Evaluator(name="ogbl-citation2")
-
-    # Since we need to evaluate the model, we need to set the number
-    # of layers to 3 and the fanout to -1.
-    args.fanout = [-1] * 3
-    dataloader = create_dataloader(
-        args, graph, features, itemset, is_train=False
-    )
-    pos_pred = []
-    neg_pred = []
+def compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst):
+    """Compute the Mean Reciprocal Rank (MRR) for given source and destination
+    nodes.
+
+    This function computes the MRR for a set of node pairs, dividing the task
+    into batches to handle potentially large graphs.
+    """
+    rr = torch.zeros(src.shape[0])
+    # Loop over node pairs in batches.
+    for start in tqdm.trange(
+        0, src.shape[0], args.eval_batch_size, desc="Evaluate"
+    ):
+        end = min(start + args.eval_batch_size, src.shape[0])
+        # Concatenate positive and negative destination nodes.
+        all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
+        # Fetch embeddings for the current batch of source and destination
+        # nodes.
+        h_src = node_emb[src[start:end]][:, None, :].to(args.device)
+        h_dst = (
+            node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(args.device)
+        )
+        # Compute prediction scores using the model.
+        pred = model.predictor(h_src * h_dst).squeeze(-1)
+        # Evaluate the predictions to obtain MRR values.
+        input_dict = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
+        rr[start:end] = evaluator.eval(input_dict)["mrr_list"]
+    return rr.mean()
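
For intuition, the per-pair MRR that evaluator.eval returns can be recomputed by hand: rank each positive score against its negatives and take the reciprocal. A toy recomputation (ignoring OGB's exact tie-handling details):

import torch

# Toy scores: 3 queries, each with 1 positive and 4 negatives.
pos = torch.tensor([2.0, 0.5, 1.0])
neg = torch.tensor([[1.0, 3.0, 0.0, 0.5],
                    [0.1, 0.2, 0.3, 0.4],
                    [2.0, 3.0, 4.0, 5.0]])

# Rank of each positive = 1 + number of negatives scoring higher.
rank = 1 + (neg > pos[:, None]).sum(dim=1)
rr = 1.0 / rank.float()  # reciprocal ranks: [0.5, 1.0, 0.2]
print(rr.mean())         # MRR over the three queries
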
+@torch.no_grad()
+def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
+    """Evaluate the model on validation and test sets."""
+    model.eval()
+    evaluator = Evaluator(name="ogbl-citation2")
-    for step, data in tqdm.tqdm(enumerate(dataloader)):
-        # Unpack MiniBatch.
-        compacted_pairs, _ = to_binary_link_dgl_computing_pack(data)
-        node_feature = data.node_features["feat"].float()
-        blocks = data.blocks
-        # Get the embeddings of the input nodes.
-        y = model(blocks, node_feature)
-        # Calculate the score for positive and negative edges.
-        score = (
-            model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])
-            .squeeze()
-            .detach()
-        )
-        # Split the score into positive and negative parts.
-        pos_score = score[: data.positive_node_pairs[0].shape[0]]
-        neg_score = score[data.positive_node_pairs[0].shape[0] :]
-        # Append the score to the list.
-        pos_pred.append(pos_score)
-        neg_pred.append(neg_score)
-    pos_pred = torch.cat(pos_pred, dim=0)
-    neg_pred = torch.cat(neg_pred, dim=0).view(pos_pred.shape[0], -1)
-    input_dict = {"y_pred_pos": pos_pred, "y_pred_neg": neg_pred}
-    mrr = evaluator.eval(input_dict)["mrr_list"]
-    return mrr.mean()
+    # Since we need to use all neighborhoods for evaluation, we set the fanout
+    # to -1.
+    args.fanout = [-1]
+    dataloader = create_dataloader(
+        args, graph, features, all_nodes_set, is_train=False
+    )
+
+    # Compute node embeddings for the entire graph.
+    node_emb = model.inference(graph, features, dataloader, args.device)
+    results = []
+
+    # Loop over both validation and test sets.
+    for split in [valid_set, test_set]:
+        # Unpack the item set.
+        src = split._items[0][:, 0].to(node_emb.device)
+        dst = split._items[0][:, 1].to(node_emb.device)
+        neg_dst = split._items[1].to(node_emb.device)
+
+        # Compute MRR values for the current split.
+        results.append(
+            compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst)
+        )
+    return results
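
The split._items indexing above assumes the ogbl-citation2 item-set layout: index 0 holds the positive (src, dst) pairs and index 1 the per-source negative destinations (1,000 per source in this dataset). Toy tensors with that assumed layout show what compute_mrr ends up ranking:

import torch

# Hypothetical shapes mirroring the assumed ogbl-citation2 layout.
N = 8
pairs = torch.randint(0, 100, (N, 2))       # _items[0]: positive (src, dst)
neg_dst = torch.randint(0, 100, (N, 1000))  # _items[1]: negatives per source

src, dst = pairs[:, 0], pairs[:, 1]
# For each source, 1 positive + 1000 negative candidates are ranked.
all_dst = torch.cat([dst[:, None], neg_dst], dim=1)
print(all_dst.shape)  # torch.Size([8, 1001])
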
-def train(args, graph, features, train_set, valid_set, model):
+def train(args, model, graph, features, train_set):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     dataloader = create_dataloader(args, graph, features, train_set)
 
     for epoch in tqdm.trange(args.epochs):
         model.train()
         total_loss = 0
+        start_epoch_time = time.time()
         for step, data in enumerate(dataloader):
             # Unpack MiniBatch.
             compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
-            node_feature = data.node_features["feat"].float()
+            node_feature = data.node_features["feat"]
 
             # Convert sampled subgraphs to DGL blocks.
             blocks = data.blocks
@@ -293,20 +351,15 @@ def train(args, graph, features, train_set, valid_set, model):
             optimizer.step()
             total_loss += loss.item()
-            if (step % 100 == 0) and (step != 0):
-                print(
-                    f"Epoch {epoch:05d} | "
-                    f"Step {step:05d} | "
-                    f"Loss {(total_loss) / (step + 1):.4f}",
-                    end="\n",
-                )
             if step + 1 == args.early_stop:
                 break
 
-        # Evaluate the model.
-        print("Validation")
-        valid_mrr = evaluate(args, graph, features, valid_set, model)
-        print(f"Valid MRR {valid_mrr.item():.4f}")
+        end_epoch_time = time.time()
+        print(
+            f"Epoch {epoch:05d} | "
+            f"Loss {(total_loss) / (step + 1):.4f} | "
+            f"Time {(end_epoch_time - start_epoch_time):.4f} s"
+        )
 
 
 def parse_args():
@@ -315,10 +368,7 @@ def parse_args():
     parser.add_argument("--lr", type=float, default=0.0005)
     parser.add_argument("--neg-ratio", type=int, default=1)
     parser.add_argument("--train-batch-size", type=int, default=512)
-    # TODO [Issue#6534]: Use model.inference instead of dataloader to evaluate.
-    # Since neg_ratio in the valid/test set is 1000, which is too large for GPU
-    # memory, we should use a small batch size to evaluate.
-    parser.add_argument("--eval-batch-size", type=int, default=2)
+    parser.add_argument("--eval-batch-size", type=int, default=1024)
     parser.add_argument("--num-workers", type=int, default=4)
     parser.add_argument(
         "--early-stop",
@@ -352,7 +402,6 @@ def main(args):
     graph = dataset.graph
     features = dataset.feature
     train_set = dataset.tasks[0].train_set
-    valid_set = dataset.tasks[0].validation_set
     args.fanout = list(map(int, args.fanout.split(",")))
     in_size = features.size("node", None, "feat")[0]
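
For reference, the --fanout argument parsed just above is a comma-separated string converted into one integer per layer:

fanout = "15,10,5"                        # e.g. --fanout 15,10,5
print(list(map(int, fanout.split(","))))  # [15, 10, 5]
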
@@ -362,13 +411,20 @@ def main(args):
     # Model training.
     print("Training...")
-    train(args, graph, features, train_set, valid_set, model)
+    train(args, model, graph, features, train_set)
 
     # Test the model.
     print("Testing...")
     test_set = dataset.tasks[0].test_set
-    test_mrr = evaluate(args, graph, features, test_set, model)
-    print(f"Test MRR {test_mrr.item():.4f}")
+    valid_set = dataset.tasks[0].validation_set
+    all_nodes_set = dataset.all_nodes_set
+    valid_mrr, test_mrr = evaluate(
+        args, model, graph, features, all_nodes_set, valid_set, test_set
+    )
+    print(
+        f"Validation MRR {valid_mrr.item():.4f}, "
+        f"Test MRR {test_mrr.item():.4f}"
+    )
 
 
 if __name__ == "__main__":