Unverified Commit a8e672e7 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Benchmark] Align the DGL/Graphbolt link pred examples (#6609)

parent 77ec365d
...@@ -189,7 +189,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True): ...@@ -189,7 +189,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# to ensure that positive samples are not inadvertently included within # to ensure that positive samples are not inadvertently included within
# the negative samples. # the negative samples.
############################################################################ ############################################################################
if is_train: if is_train and args.exclude_edges:
datapipe = datapipe.transform(gb.exclude_seed_edges) datapipe = datapipe.transform(gb.exclude_seed_edges)
############################################################################ ############################################################################
...@@ -369,7 +369,7 @@ def parse_args(): ...@@ -369,7 +369,7 @@ def parse_args():
parser.add_argument("--neg-ratio", type=int, default=1) parser.add_argument("--neg-ratio", type=int, default=1)
parser.add_argument("--train-batch-size", type=int, default=512) parser.add_argument("--train-batch-size", type=int, default=512)
parser.add_argument("--eval-batch-size", type=int, default=1024) parser.add_argument("--eval-batch-size", type=int, default=1024)
parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument( parser.add_argument(
"--early-stop", "--early-stop",
type=int, type=int,
...@@ -382,6 +382,12 @@ def parse_args(): ...@@ -382,6 +382,12 @@ def parse_args():
default="15,10,5", default="15,10,5",
help="Fan-out of neighbor sampling. Default: 15,10,5", help="Fan-out of neighbor sampling. Default: 15,10,5",
) )
parser.add_argument(
"--exclude-edges",
type=int,
default=1,
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
default="cpu", default="cpu",
......
...@@ -283,7 +283,9 @@ def evaluate(device, graph, edge_split, model, batch_size): ...@@ -283,7 +283,9 @@ def evaluate(device, graph, edge_split, model, batch_size):
return results return results
def train(args, device, g, reverse_eids, seed_edges, model, use_uva): def train(
args, device, g, reverse_eids, seed_edges, model, use_uva, fused_sampling
):
##################################################################### #####################################################################
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# training of Graph Neural Networks (GNNs) on large-scale graphs. # training of Graph Neural Networks (GNNs) on large-scale graphs.
...@@ -320,11 +322,15 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva): ...@@ -320,11 +322,15 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
# not just to learn node representations, but also to predict the # not just to learn node representations, but also to predict the
# likelihood of an edge existing between two nodes (link prediction). # likelihood of an edge existing between two nodes (link prediction).
##################################################################### #####################################################################
sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"]) sampler = NeighborSampler(
[15, 10, 5],
prefetch_node_feats=["feat"],
fused=fused_sampling,
)
sampler = as_edge_prediction_sampler( sampler = as_edge_prediction_sampler(
sampler, sampler,
exclude="reverse_id", exclude="reverse_id" if args.exclude_edges else None,
reverse_eids=reverse_eids, reverse_eids=reverse_eids if args.exclude_edges else None,
negative_sampler=negative_sampler.Uniform(1), negative_sampler=negative_sampler.Uniform(1),
) )
...@@ -333,7 +339,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva): ...@@ -333,7 +339,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
seed_edges, seed_edges,
sampler, sampler,
device=device, device=device,
batch_size=args.batch_size, batch_size=args.train_batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
# If `g` is on gpu or `use_uva` is True, `num_workers` must be zero, # If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,
...@@ -342,7 +348,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva): ...@@ -342,7 +348,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
use_uva=use_uva, use_uva=use_uva,
) )
opt = torch.optim.Adam(model.parameters(), lr=args.lr) opt = torch.optim.Adam(model.parameters(), lr=args.lr)
for epoch in range(10): for epoch in range(args.epochs):
model.train() model.train()
total_loss = 0 total_loss = 0
# A block is a graph consisting of two sets of nodes: the # A block is a graph consisting of two sets of nodes: the
...@@ -377,6 +383,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva): ...@@ -377,6 +383,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument( parser.add_argument(
"--lr", "--lr",
type=float, type=float,
...@@ -384,10 +391,16 @@ def parse_args(): ...@@ -384,10 +391,16 @@ def parse_args():
help="Learning rate. Default: 0.0005", help="Learning rate. Default: 0.0005",
) )
parser.add_argument( parser.add_argument(
"--batch-size", "--train-batch-size",
type=int, type=int,
default=512, default=512,
help="Batch size. Default: 512", help="Batch size for training. Default: 512",
)
parser.add_argument(
"--eval-batch-size",
type=int,
default=1024,
help="Batch size during evaluation. Default: 1024",
) )
parser.add_argument( parser.add_argument(
"--early-stop", "--early-stop",
...@@ -395,6 +408,17 @@ def parse_args(): ...@@ -395,6 +408,17 @@ def parse_args():
default=0, default=0,
help="0 means no early stop, otherwise stop at the input-th step", help="0 means no early stop, otherwise stop at the input-th step",
) )
parser.add_argument(
"--exclude-edges",
type=int,
default=1,
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument(
"--compare-graphbolt",
action="store_true",
help="Compare with GraphBolt",
)
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="mixed", default="mixed",
...@@ -414,7 +438,11 @@ def main(args): ...@@ -414,7 +438,11 @@ def main(args):
print("Loading data") print("Loading data")
dataset = DglLinkPropPredDataset("ogbl-citation2") dataset = DglLinkPropPredDataset("ogbl-citation2")
g = dataset[0] g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu") if args.compare_graphbolt:
fused_sampling = False
else:
fused_sampling = True
g = g.to("cuda" if args.mode == "puregpu" else "cpu")
# Whether use Unified Virtual Addressing (UVA) for CUDA computation. # Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva = args.mode == "mixed" use_uva = args.mode == "mixed"
...@@ -432,12 +460,21 @@ def main(args): ...@@ -432,12 +460,21 @@ def main(args):
# Model training. # Model training.
print("Training...") print("Training...")
train(args, device, g, reverse_eids, seed_edges, model, use_uva) train(
args,
device,
g,
reverse_eids,
seed_edges,
model,
use_uva,
fused_sampling,
)
# Validate/Test the model. # Validate/Test the model.
print("Validation/Testing...") print("Validation/Testing...")
valid_mrr, test_mrr = evaluate( valid_mrr, test_mrr = evaluate(
device, g, edge_split, model, batch_size=1000 device, g, edge_split, model, batch_size=args.eval_batch_size
) )
print( print(
f"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}" f"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment