Unverified Commit ae7e3db6 authored by Hengrui Zhang's avatar Hengrui Zhang Committed by GitHub
Browse files

fix bugs in gat training script (#3505)


Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
parent 3f792071
...@@ -107,6 +107,7 @@ def main(args): ...@@ -107,6 +107,7 @@ def main(args):
for epoch in range(args.epochs): for epoch in range(args.epochs):
model.train() model.train()
if epoch >= 3: if epoch >= 3:
torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
# forward # forward
logits = model(features) logits = model(features)
...@@ -117,6 +118,7 @@ def main(args): ...@@ -117,6 +118,7 @@ def main(args):
optimizer.step() optimizer.step()
if epoch >= 3: if epoch >= 3:
torch.cuda.synchronize()
dur.append(time.time() - t0) dur.append(time.time() - t0)
train_acc = accuracy(logits[train_mask], labels[train_mask]) train_acc = accuracy(logits[train_mask], labels[train_mask])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment