Unverified Commit 6de7d5fa authored by Hao Yuan's avatar Hao Yuan Committed by GitHub
Browse files

[Bugfix] Fix conditional judgment of cuda (#4026) (#4027)

parent 4b87e47f
...@@ -107,7 +107,8 @@ def main(args): ...@@ -107,7 +107,8 @@ def main(args):
for epoch in range(args.epochs): for epoch in range(args.epochs):
model.train() model.train()
if epoch >= 3: if epoch >= 3:
torch.cuda.synchronize() if cuda:
torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
# forward # forward
logits = model(features) logits = model(features)
...@@ -118,7 +119,8 @@ def main(args): ...@@ -118,7 +119,8 @@ def main(args):
optimizer.step() optimizer.step()
if epoch >= 3: if epoch >= 3:
torch.cuda.synchronize() if cuda:
torch.cuda.synchronize()
dur.append(time.time() - t0) dur.append(time.time() - t0)
train_acc = accuracy(logits[train_mask], labels[train_mask]) train_acc = accuracy(logits[train_mask], labels[train_mask])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment