Unverified Commit 562a1c87 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the CompGCN example. (#6068)

parent 2584f3af
import argparse
from time import time
import dgl.function as fn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from data_loader import Data
from models import CompGCN_ConvE
......@@ -31,7 +27,7 @@ def predict(model, graph, device, data_iter, split="valid", mode="tail"):
pred = model(graph, sub, rel)
b_range = th.arange(pred.size()[0], device=device)
target_pred = pred[b_range, obj]
pred = th.where(label.byte(), -th.ones_like(pred) * 10000000, pred)
pred = th.where(label.bool(), -th.ones_like(pred) * 10000000, pred)
pred[b_range, obj] = target_pred
# compute metrics
......@@ -178,7 +174,6 @@ def main(args):
# validate
if val_results["mrr"] > best_mrr:
best_mrr = val_results["mrr"]
best_epoch = epoch
th.save(
compgcn_model.state_dict(), "comp_link" + "_" + args.dataset
)
......@@ -190,7 +185,7 @@ def main(args):
print("early stop.")
break
print(
"In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}".format(
"In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}, Train time: {}, Valid time: {}".format(
epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment