Unverified Commit 829c073f authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Benchmark] Fix bugs in the DGL link prediction example. (#6630)

parent 8394df7a
......@@ -40,6 +40,7 @@ main
"""
import argparse
import time
import dgl
import dgl.nn as dglnn
......@@ -348,6 +349,7 @@ def train(
use_uva=use_uva,
)
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
start_epoch_time = time.time()
for epoch in range(args.epochs):
model.train()
total_loss = 0
......@@ -378,7 +380,12 @@ def train(
total_loss += loss.item()
if (it + 1) == args.early_stop:
break
print(f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f}")
end_epoch_time = time.time()
print(
f"Epoch {epoch:05d} | "
f"Loss {total_loss / (it + 1):.4f} | "
f"Time {(end_epoch_time - start_epoch_time):.4f} s"
)
def parse_args():
......@@ -450,8 +457,8 @@ def main(args):
# Convert the graph to its bidirectional form.
g, reverse_eids = to_bidirected_with_reverse_mapping(g)
reverse_eids = reverse_eids.to(device)
seed_edges = torch.arange(g.num_edges()).to(device)
reverse_eids = reverse_eids.to(g.device)
seed_edges = torch.arange(g.num_edges()).to(g.device)
edge_split = dataset.get_edge_split()
# Create GraphSAGE model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment