Unverified Commit 1f931011 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the GATv2 example. (#6035)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 017d9d40
...@@ -89,7 +89,7 @@ def main(args): ...@@ -89,7 +89,7 @@ def main(args):
val_mask = g.ndata["val_mask"] val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"] test_mask = g.ndata["test_mask"]
num_feats = features.shape[1] num_feats = features.shape[1]
n_classes = data.num_labels n_classes = data.num_classes
n_edges = g.num_edges() n_edges = g.num_edges()
print( print(
"""----Data statistics------' """----Data statistics------'
...@@ -138,7 +138,7 @@ def main(args): ...@@ -138,7 +138,7 @@ def main(args):
) )
# initialize graph # initialize graph
dur = [] mean = 0
for epoch in range(args.epochs): for epoch in range(args.epochs):
model.train() model.train()
if epoch >= 3: if epoch >= 3:
...@@ -152,29 +152,29 @@ def main(args): ...@@ -152,29 +152,29 @@ def main(args):
optimizer.step() optimizer.step()
if epoch >= 3: if epoch >= 3:
dur.append(time.time() - t0) mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
train_acc = accuracy(logits[train_mask], labels[train_mask]) train_acc = accuracy(logits[train_mask], labels[train_mask])
if args.fastmode: if args.fastmode:
val_acc = accuracy(logits[val_mask], labels[val_mask]) val_acc = accuracy(logits[val_mask], labels[val_mask])
else: else:
val_acc = evaluate(g, model, features, labels, val_mask) val_acc = evaluate(g, model, features, labels, val_mask)
if args.early_stop: if args.early_stop:
if stopper.step(val_acc, model): if stopper.step(val_acc, model):
break break
print( print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |" "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format( " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, epoch,
np.mean(dur), mean,
loss.item(), loss.item(),
train_acc, train_acc,
val_acc, val_acc,
n_edges / np.mean(dur) / 1000, n_edges / mean / 1000,
)
) )
)
print() print()
if args.early_stop: if args.early_stop:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment