Unverified Commit 93074dd3 authored by Roger Tu's avatar Roger Tu Committed by GitHub
Browse files

Bug fix for RGCN Link Prediction Example Exceeding Specified Epochs (#2757)

* Address bug where epochs exceed specified limit

* fix spacing
parent d3a05822
...@@ -191,13 +191,13 @@ def main(args): ...@@ -191,13 +191,13 @@ def main(args):
valid_data, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size, valid_data, test_data, hits=[1, 3, 10], eval_bz=args.eval_batch_size,
eval_p=args.eval_protocol) eval_p=args.eval_protocol)
# save best model # save best model
if mrr < best_mrr: if best_mrr < mrr:
best_mrr = mrr
torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, model_state_file)
if epoch >= args.n_epochs: if epoch >= args.n_epochs:
break break
else:
best_mrr = mrr
torch.save({'state_dict': model.state_dict(), 'epoch': epoch},
model_state_file)
if use_cuda: if use_cuda:
model.cuda() model.cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment