Unverified commit a8e672e7, authored by LastWhisper, committed by GitHub
Browse files

[Benchmark] Align the DGL/Graphbolt link pred examples (#6609)

parent 77ec365d
......@@ -189,7 +189,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# to ensure that positive samples are not inadvertently included within
# the negative samples.
############################################################################
if is_train:
if is_train and args.exclude_edges:
datapipe = datapipe.transform(gb.exclude_seed_edges)
############################################################################
......@@ -369,7 +369,7 @@ def parse_args():
parser.add_argument("--neg-ratio", type=int, default=1)
parser.add_argument("--train-batch-size", type=int, default=512)
parser.add_argument("--eval-batch-size", type=int, default=1024)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument(
"--early-stop",
type=int,
......@@ -382,6 +382,12 @@ def parse_args():
default="15,10,5",
help="Fan-out of neighbor sampling. Default: 15,10,5",
)
parser.add_argument(
"--exclude-edges",
type=int,
default=1,
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument(
"--device",
default="cpu",
......
......@@ -283,7 +283,9 @@ def evaluate(device, graph, edge_split, model, batch_size):
return results
def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
def train(
args, device, g, reverse_eids, seed_edges, model, use_uva, fused_sampling
):
#####################################################################
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# training of Graph Neural Networks (GNNs) on large-scale graphs.
......@@ -320,11 +322,15 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
# not just to learn node representations, but also to predict the
# likelihood of an edge existing between two nodes (link prediction).
#####################################################################
sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"])
sampler = NeighborSampler(
[15, 10, 5],
prefetch_node_feats=["feat"],
fused=fused_sampling,
)
sampler = as_edge_prediction_sampler(
sampler,
exclude="reverse_id",
reverse_eids=reverse_eids,
exclude="reverse_id" if args.exclude_edges else None,
reverse_eids=reverse_eids if args.exclude_edges else None,
negative_sampler=negative_sampler.Uniform(1),
)
......@@ -333,7 +339,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
seed_edges,
sampler,
device=device,
batch_size=args.batch_size,
batch_size=args.train_batch_size,
shuffle=True,
drop_last=False,
# If `g` is on gpu or `use_uva` is True, `num_workers` must be zero,
......@@ -342,7 +348,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
use_uva=use_uva,
)
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
for epoch in range(10):
for epoch in range(args.epochs):
model.train()
total_loss = 0
# A block is a graph consisting of two sets of nodes: the
......@@ -377,6 +383,7 @@ def train(args, device, g, reverse_eids, seed_edges, model, use_uva):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument(
"--lr",
type=float,
......@@ -384,10 +391,16 @@ def parse_args():
help="Learning rate. Default: 0.0005",
)
parser.add_argument(
"--batch-size",
"--train-batch-size",
type=int,
default=512,
help="Batch size. Default: 512",
help="Batch size for training. Default: 512",
)
parser.add_argument(
"--eval-batch-size",
type=int,
default=1024,
help="Batch size during evaluation. Default: 1024",
)
parser.add_argument(
"--early-stop",
......@@ -395,6 +408,17 @@ def parse_args():
default=0,
help="0 means no early stop, otherwise stop at the input-th step",
)
parser.add_argument(
"--exclude-edges",
type=int,
default=1,
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument(
"--compare-graphbolt",
action="store_true",
help="Compare with GraphBolt",
)
parser.add_argument(
"--mode",
default="mixed",
......@@ -414,7 +438,11 @@ def main(args):
print("Loading data")
dataset = DglLinkPropPredDataset("ogbl-citation2")
g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu")
if args.compare_graphbolt:
fused_sampling = False
else:
fused_sampling = True
g = g.to("cuda" if args.mode == "puregpu" else "cpu")
# Whether to use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva = args.mode == "mixed"
......@@ -432,12 +460,21 @@ def main(args):
# Model training.
print("Training...")
train(args, device, g, reverse_eids, seed_edges, model, use_uva)
train(
args,
device,
g,
reverse_eids,
seed_edges,
model,
use_uva,
fused_sampling,
)
# Validate/Test the model.
print("Validation/Testing...")
valid_mrr, test_mrr = evaluate(
device, g, edge_split, model, batch_size=1000
device, g, edge_split, model, batch_size=args.eval_batch_size
)
print(
f"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment