"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6ea83608adcf9302d3d24733dc319ab4ea9607ad"
Unverified Commit 87a41e50 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the MoNet example. (#5999)


Co-authored-by: default avatarrudongyu <ru_dongyu@outlook.com>
parent cf5c1930
......@@ -78,7 +78,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.num_edges()
print(
"""----Data statistics------'
......@@ -127,7 +127,7 @@ def main(args):
)
# initialize graph
dur = []
mean = 0
for epoch in range(args.n_epochs):
model.train()
if epoch >= 3:
......@@ -141,19 +141,18 @@ def main(args):
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
acc = evaluate(model, features, pseudo, labels, val_mask)
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
acc,
n_edges / np.mean(dur) / 1000,
mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
acc = evaluate(model, features, pseudo, labels, val_mask)
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
mean,
loss.item(),
acc,
n_edges / mean / 1000,
)
)
)
print()
acc = evaluate(model, features, pseudo, labels, test_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment