Unverified Commit 562a1c87 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the CompGCN example. (#6068)

parent 2584f3af
import argparse import argparse
from time import time from time import time
import dgl.function as fn
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from data_loader import Data from data_loader import Data
from models import CompGCN_ConvE from models import CompGCN_ConvE
...@@ -31,7 +27,7 @@ def predict(model, graph, device, data_iter, split="valid", mode="tail"): ...@@ -31,7 +27,7 @@ def predict(model, graph, device, data_iter, split="valid", mode="tail"):
pred = model(graph, sub, rel) pred = model(graph, sub, rel)
b_range = th.arange(pred.size()[0], device=device) b_range = th.arange(pred.size()[0], device=device)
target_pred = pred[b_range, obj] target_pred = pred[b_range, obj]
pred = th.where(label.byte(), -th.ones_like(pred) * 10000000, pred) pred = th.where(label.bool(), -th.ones_like(pred) * 10000000, pred)
pred[b_range, obj] = target_pred pred[b_range, obj] = target_pred
# compute metrics # compute metrics
...@@ -178,7 +174,6 @@ def main(args): ...@@ -178,7 +174,6 @@ def main(args):
# validate # validate
if val_results["mrr"] > best_mrr: if val_results["mrr"] > best_mrr:
best_mrr = val_results["mrr"] best_mrr = val_results["mrr"]
best_epoch = epoch
th.save( th.save(
compgcn_model.state_dict(), "comp_link" + "_" + args.dataset compgcn_model.state_dict(), "comp_link" + "_" + args.dataset
) )
...@@ -190,7 +185,7 @@ def main(args): ...@@ -190,7 +185,7 @@ def main(args):
print("early stop.") print("early stop.")
break break
print( print(
"In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}\n, Train time: {}, Valid time: {}".format( "In epoch {}, Train Loss: {:.4f}, Valid MRR: {:.5}, Train time: {}, Valid time: {}".format(
epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1 epoch, train_loss, val_results["mrr"], t1 - t0, t2 - t1
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment