"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "72eae64d67a4c8430aced38faaf8c09f79012469"
Unverified Commit e82f1746 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Graphbolt] Enable Link Prediction example model.inference to accelerate the...


[Graphbolt] Enable Link Prediction example model.inference to accelerate the evaluation speed (#6557)
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-12-115.ap-northeast-1.compute.internal>
parent 7d43c769
...@@ -42,8 +42,8 @@ main ...@@ -42,8 +42,8 @@ main
└───> Test set evaluation └───> Test set evaluation
""" """
import argparse import argparse
import time
import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import dgl.nn as dglnn import dgl.nn as dglnn
import torch import torch
...@@ -78,6 +78,40 @@ class SAGE(nn.Module): ...@@ -78,6 +78,40 @@ class SAGE(nn.Module):
hidden_x = F.relu(hidden_x) hidden_x = F.relu(hidden_x)
return hidden_x return hidden_x
    def inference(self, graph, features, dataloader, device):
        """Conduct layer-wise inference to get all the node embeddings.

        Instead of sampling multi-hop neighborhoods per seed node, this runs
        one full pass over ALL nodes per GNN layer: layer k's output for every
        node is materialized in a CPU buffer, then fed as the input features
        to layer k+1. This trades memory for a large speedup at evaluation
        time.

        Parameters
        ----------
        graph : the GraphBolt graph; only ``total_num_nodes`` is read here.
        features : feature store; node features are read via
            ``features.read("node", None, "feat")``.
        dataloader : a single-layer (fanout [-1]) node dataloader; each
            minibatch must carry exactly one block (``len(data.blocks) == 1``)
            and contiguous ``output_nodes`` (see note below).
        device : torch device the model runs on; minibatch inputs are moved
            there, outputs are copied back to CPU.

        Returns
        -------
        A ``(graph.total_num_nodes, self.hidden_size)`` float32 CPU tensor of
        final-layer node embeddings.
        """
        # Full node-feature matrix; starts as the raw "feat" features and is
        # replaced by each layer's output after every pass over the graph.
        feature = features.read("node", None, "feat")
        # Per-layer outputs are accumulated on CPU so the whole embedding
        # matrix need not fit in GPU memory.
        buffer_device = torch.device("cpu")
        # Enable pin_memory for faster CPU to GPU data transfer if the
        # model is running on a GPU.
        pin_memory = buffer_device != device
        print("Start node embedding inference.")
        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx == len(self.layers) - 1
            # Output buffer for this layer, covering every node in the graph.
            # NOTE(review): width is always self.hidden_size — assumes the
            # final layer also emits hidden_size-wide embeddings; confirm
            # against the SAGE constructor.
            y = torch.empty(
                graph.total_num_nodes,
                self.hidden_size,
                dtype=torch.float32,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            # Move the full input matrix for this layer onto the compute
            # device once, rather than per minibatch.
            feature = feature.to(device)
            for step, data in tqdm.tqdm(enumerate(dataloader)):
                # Gather input features for this minibatch's sampled nodes.
                x = feature[data.input_nodes]
                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                # ReLU between layers only; the last layer's output is the
                # raw embedding used by the predictor.
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                # By design, our output nodes are contiguous.
                # NOTE(review): non_blocking=True device->pinned-CPU copies
                # are asynchronous; `y` is read below as the next layer's
                # input without an explicit torch.cuda.synchronize() — the
                # subsequent `.to(device)` presumably serializes on the same
                # stream, but confirm this ordering is safe.
                y[
                    data.output_nodes[0] : data.output_nodes[-1] + 1
                ] = hidden_x.to(buffer_device, non_blocking=True)
            # This layer's full output becomes the next layer's input.
            feature = y
        return y
def create_dataloader(args, graph, features, itemset, is_train=True): def create_dataloader(args, graph, features, itemset, is_train=True):
"""Get a GraphBolt version of a dataloader for link prediction tasks. This """Get a GraphBolt version of a dataloader for link prediction tasks. This
...@@ -166,9 +200,12 @@ def create_dataloader(args, graph, features, itemset, is_train=True): ...@@ -166,9 +200,12 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# A FeatureFetcher object to fetch node features. # A FeatureFetcher object to fetch node features.
# [Role]: # [Role]:
# Initialize a feature fetcher for fetching features of the sampled # Initialize a feature fetcher for fetching features of the sampled
# subgraphs. # subgraphs. This step is skipped in evaluation/inference because features
# are updated as a whole during it, thus storing features in minibatch is
# unnecessary.
############################################################################ ############################################################################
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) if is_train:
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
############################################################################ ############################################################################
# [Step-4]: # [Step-4]:
...@@ -223,60 +260,81 @@ def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch): ...@@ -223,60 +260,81 @@ def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
@torch.no_grad() @torch.no_grad()
def evaluate(args, graph, features, itemset, model): def compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst):
"""Compute the Mean Reciprocal Rank (MRR) for given source and destination
nodes.
This function computes the MRR for a set of node pairs, dividing the task
into batches to handle potentially large graphs.
"""
rr = torch.zeros(src.shape[0])
# Loop over node pairs in batches.
for start in tqdm.trange(
0, src.shape[0], args.eval_batch_size, desc="Evaluate"
):
end = min(start + args.eval_batch_size, src.shape[0])
# Concatenate positive and negative destination nodes.
all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
# Fetch embeddings for current batch of source and destination nodes.
h_src = node_emb[src[start:end]][:, None, :].to(args.device)
h_dst = (
node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(args.device)
)
# Compute prediction scores using the model.
pred = model.predictor(h_src * h_dst).squeeze(-1)
# Evaluate the predictions to obtain MRR values.
input_dict = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
rr[start:end] = evaluator.eval(input_dict)["mrr_list"]
return rr.mean()
@torch.no_grad()
def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
"""Evaluate the model on validation and test sets."""
model.eval()
evaluator = Evaluator(name="ogbl-citation2") evaluator = Evaluator(name="ogbl-citation2")
# Since we need to evaluate the model, we need to set the number # Since we need to use all neghborhoods for evaluation, we set the fanout
# of layers to 3 and the fanout to -1. # to -1.
args.fanout = [-1] * 3 args.fanout = [-1]
dataloader = create_dataloader( dataloader = create_dataloader(
args, graph, features, itemset, is_train=False args, graph, features, all_nodes_set, is_train=False
) )
pos_pred = []
neg_pred = []
model.eval()
for step, data in tqdm.tqdm(enumerate(dataloader)):
# Unpack MiniBatch.
compacted_pairs, _ = to_binary_link_dgl_computing_pack(data)
node_feature = data.node_features["feat"].float()
blocks = data.blocks
# Get the embeddings of the input nodes.
y = model(blocks, node_feature)
# Calculate the score for positive and negative edges.
score = (
model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])
.squeeze()
.detach()
)
# Split the score into positive and negative parts. # Compute node embeddings for the entire graph.
pos_score = score[: data.positive_node_pairs[0].shape[0]] node_emb = model.inference(graph, features, dataloader, args.device)
neg_score = score[data.positive_node_pairs[0].shape[0] :] results = []
# Append the score to the list. # Loop over both validation and test sets.
pos_pred.append(pos_score) for split in [valid_set, test_set]:
neg_pred.append(neg_score) # Unpack the item set.
pos_pred = torch.cat(pos_pred, dim=0) src = split._items[0][:, 0].to(node_emb.device)
neg_pred = torch.cat(neg_pred, dim=0).view(pos_pred.shape[0], -1) dst = split._items[0][:, 1].to(node_emb.device)
neg_dst = split._items[1].to(node_emb.device)
input_dict = {"y_pred_pos": pos_pred, "y_pred_neg": neg_pred} # Compute MRR values for the current split.
mrr = evaluator.eval(input_dict)["mrr_list"] results.append(
return mrr.mean() compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst)
)
return results
def train(args, graph, features, train_set, valid_set, model): def train(args, model, graph, features, train_set):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
dataloader = create_dataloader(args, graph, features, train_set) dataloader = create_dataloader(args, graph, features, train_set)
for epoch in tqdm.trange(args.epochs): for epoch in tqdm.trange(args.epochs):
model.train() model.train()
total_loss = 0 total_loss = 0
start_epoch_time = time.time()
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
# Unpack MiniBatch. # Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data) compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
node_feature = data.node_features["feat"].float() node_feature = data.node_features["feat"]
# Convert sampled subgraphs to DGL blocks. # Convert sampled subgraphs to DGL blocks.
blocks = data.blocks blocks = data.blocks
...@@ -293,20 +351,15 @@ def train(args, graph, features, train_set, valid_set, model): ...@@ -293,20 +351,15 @@ def train(args, graph, features, train_set, valid_set, model):
optimizer.step() optimizer.step()
total_loss += loss.item() total_loss += loss.item()
if (step % 100 == 0) and (step != 0):
print(
f"Epoch {epoch:05d} | "
f"Step {step:05d} | "
f"Loss {(total_loss) / (step + 1):.4f}",
end="\n",
)
if step + 1 == args.early_stop: if step + 1 == args.early_stop:
break break
# Evaluate the model. end_epoch_time = time.time()
print("Validation") print(
valid_mrr = evaluate(args, graph, features, valid_set, model) f"Epoch {epoch:05d} | "
print(f"Valid MRR {valid_mrr.item():.4f}") f"Loss {(total_loss) / (step + 1):.4f} | "
f"Time {(end_epoch_time - start_epoch_time):.4f} s"
)
def parse_args(): def parse_args():
...@@ -315,10 +368,7 @@ def parse_args(): ...@@ -315,10 +368,7 @@ def parse_args():
parser.add_argument("--lr", type=float, default=0.0005) parser.add_argument("--lr", type=float, default=0.0005)
parser.add_argument("--neg-ratio", type=int, default=1) parser.add_argument("--neg-ratio", type=int, default=1)
parser.add_argument("--train-batch-size", type=int, default=512) parser.add_argument("--train-batch-size", type=int, default=512)
# TODO [Issue#6534]: Use model.inference instead of dataloader to evaluate. parser.add_argument("--eval-batch-size", type=int, default=1024)
# Since neg_ratio in valid/test set is 1000, which is too large to GPU
# memory, we should use small batch size to evaluate.
parser.add_argument("--eval-batch-size", type=int, default=2)
parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument( parser.add_argument(
"--early-stop", "--early-stop",
...@@ -352,7 +402,6 @@ def main(args): ...@@ -352,7 +402,6 @@ def main(args):
graph = dataset.graph graph = dataset.graph
features = dataset.feature features = dataset.feature
train_set = dataset.tasks[0].train_set train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
args.fanout = list(map(int, args.fanout.split(","))) args.fanout = list(map(int, args.fanout.split(",")))
in_size = features.size("node", None, "feat")[0] in_size = features.size("node", None, "feat")[0]
...@@ -362,13 +411,20 @@ def main(args): ...@@ -362,13 +411,20 @@ def main(args):
# Model training. # Model training.
print("Training...") print("Training...")
train(args, graph, features, train_set, valid_set, model) train(args, model, graph, features, train_set)
# Test the model. # Test the model.
print("Testing...") print("Testing...")
test_set = dataset.tasks[0].test_set test_set = dataset.tasks[0].test_set
test_mrr = evaluate(args, graph, features, test_set, model) valid_set = dataset.tasks[0].validation_set
print(f"Test MRR {test_mrr.item():.4f}") all_nodes_set = dataset.all_nodes_set
valid_mrr, test_mrr = evaluate(
args, model, graph, features, all_nodes_set, valid_set, test_set
)
print(
f"Validation MRR {valid_mrr.item():.4f}, "
f"Test MRR {test_mrr.item():.4f}"
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment