Unverified Commit 829c073f authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Benchmark] Fix bugs in the DGL link prediction example. (#6630)

parent 8394df7a
...@@ -40,6 +40,7 @@ main ...@@ -40,6 +40,7 @@ main
""" """
import argparse import argparse
import time
import dgl import dgl
import dgl.nn as dglnn import dgl.nn as dglnn
...@@ -348,6 +349,7 @@ def train( ...@@ -348,6 +349,7 @@ def train(
use_uva=use_uva, use_uva=use_uva,
) )
opt = torch.optim.Adam(model.parameters(), lr=args.lr) opt = torch.optim.Adam(model.parameters(), lr=args.lr)
start_epoch_time = time.time()
for epoch in range(args.epochs): for epoch in range(args.epochs):
model.train() model.train()
total_loss = 0 total_loss = 0
...@@ -378,7 +380,12 @@ def train( ...@@ -378,7 +380,12 @@ def train(
total_loss += loss.item() total_loss += loss.item()
if (it + 1) == args.early_stop: if (it + 1) == args.early_stop:
break break
print(f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f}") end_epoch_time = time.time()
print(
f"Epoch {epoch:05d} | "
f"Loss {total_loss / (it + 1):.4f} | "
f"Time {(end_epoch_time - start_epoch_time):.4f} s"
)
def parse_args(): def parse_args():
...@@ -450,8 +457,8 @@ def main(args): ...@@ -450,8 +457,8 @@ def main(args):
# Convert the graph to its bidirectional form. # Convert the graph to its bidirectional form.
g, reverse_eids = to_bidirected_with_reverse_mapping(g) g, reverse_eids = to_bidirected_with_reverse_mapping(g)
reverse_eids = reverse_eids.to(device) reverse_eids = reverse_eids.to(g.device)
seed_edges = torch.arange(g.num_edges()).to(device) seed_edges = torch.arange(g.num_edges()).to(g.device)
edge_split = dataset.get_edge_split() edge_split = dataset.get_edge_split()
# Create GraphSAGE model. # Create GraphSAGE model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment