Commit f44e9830 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

reverse the early stop check back to originl now but the 'enabled must be a...

reverse the early stop check back to originl now but the 'enabled must be a bool (got Tensor)' error still persists
parent 0caffd57
......@@ -535,10 +535,7 @@ class AlphaFold(nn.Module):
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1) or early_stop
enable_grad= is_grad_enabled and is_final_iter
if (type(enable_grad)!=bool) and (type(enable_grad)==torch.Tensor):
enable_grad = enable_grad.item()
with torch.set_grad_enabled(enable_grad):
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment