"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fa71fb447803d269d1050fe5081e2cb577b04b94"
Unverified Commit 158b0fcd authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the citation_network example. (#6032)

parent 61b2f4e2
...@@ -61,7 +61,7 @@ def main(args): ...@@ -61,7 +61,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
in_feats = features.shape[1] in_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.num_edges() n_edges = g.num_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
...@@ -108,7 +108,7 @@ def main(args): ...@@ -108,7 +108,7 @@ def main(args):
) )
# initialize graph # initialize graph
dur = [] mean = 0
for epoch in range(200): for epoch in range(200):
model.train() model.train()
if epoch >= 3: if epoch >= 3:
...@@ -122,19 +122,18 @@ def main(args): ...@@ -122,19 +122,18 @@ def main(args):
optimizer.step() optimizer.step()
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
acc = evaluate(model, features, labels, val_mask)
acc = evaluate(model, features, labels, val_mask) print(
print( "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | " "ETputs(KTEPS) {:.2f}".format(
"ETputs(KTEPS) {:.2f}".format( epoch,
epoch, mean,
np.mean(dur), loss.item(),
loss.item(), acc,
acc, n_edges / mean / 1000,
n_edges / np.mean(dur) / 1000, )
) )
)
print() print()
acc = evaluate(model, features, labels, test_mask) acc = evaluate(model, features, labels, test_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment