Commit 2f2793f4 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

added an extra step because sometimes the boolean value will change into a pytorch tensor

parent 566ca1a3
......@@ -535,7 +535,10 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1) or early_stop
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
enable_grad= is_grad_enabled and is_final_iter
if (type(enable_grad)!=bool) and (type(enable_grad)==torch.Tensor):
enable_grad = enable_grad.item()
with torch.set_grad_enabled(enable_grad):
if is_final_iter:
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment