Unverified Commit 885b0eae authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the DGI example. (#6034)

parent 1f931011
...@@ -80,7 +80,7 @@ def main(args): ...@@ -80,7 +80,7 @@ def main(args):
cnt_wait = 0 cnt_wait = 0
best = 1e9 best = 1e9
best_t = 0 best_t = 0
dur = [] mean = 0
for epoch in range(args.n_dgi_epochs): for epoch in range(args.n_dgi_epochs):
dgi.train() dgi.train()
if epoch >= 3: if epoch >= 3:
...@@ -104,12 +104,12 @@ def main(args): ...@@ -104,12 +104,12 @@ def main(args):
break break
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
print( print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | " "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
"ETputs(KTEPS) {:.2f}".format( "ETputs(KTEPS) {:.2f}".format(
epoch, np.mean(dur), loss.item(), n_edges / np.mean(dur) / 1000 epoch, mean, loss.item(), n_edges / mean / 1000
) )
) )
...@@ -129,7 +129,7 @@ def main(args): ...@@ -129,7 +129,7 @@ def main(args):
dgi.load_state_dict(torch.load("best_dgi.pkl")) dgi.load_state_dict(torch.load("best_dgi.pkl"))
embeds = dgi.encoder(features, corrupt=False) embeds = dgi.encoder(features, corrupt=False)
embeds = embeds.detach() embeds = embeds.detach()
dur = [] mean = 0
for epoch in range(args.n_classifier_epochs): for epoch in range(args.n_classifier_epochs):
classifier.train() classifier.train()
if epoch >= 3: if epoch >= 3:
...@@ -142,17 +142,17 @@ def main(args): ...@@ -142,17 +142,17 @@ def main(args):
classifier_optimizer.step() classifier_optimizer.step()
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
acc = evaluate(classifier, embeds, labels, val_mask) acc = evaluate(classifier, embeds, labels, val_mask)
print( print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format( "ETputs(KTEPS) {:.2f}".format(
epoch, epoch,
np.mean(dur), mean,
loss.item(), loss.item(),
acc, acc,
n_edges / np.mean(dur) / 1000, n_edges / mean / 1000,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment