Commit ceb95303 authored by Gustaf Ahdritz

Remove call to removed function

parent 82e02065
@@ -434,7 +434,7 @@ class AlphaFold(nn.Module):
         # Main recycling loop
         num_iters = batch["aatype"].shape[-1]
         for cycle_no in range(num_iters):
             # Select the features for the current recycling cycle
             fetch_cur_batch = lambda t: t[..., cycle_no]
             feats = tensor_tree_map(fetch_cur_batch, batch)
@@ -443,7 +443,6 @@ class AlphaFold(nn.Module):
             is_final_iter = cycle_no == (num_iters - 1)
             with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                 if is_final_iter:
-                    self._enable_activation_checkpointing()
                     # Sidestep AMP bug (PyTorch issue #65766)
                     if torch.is_autocast_enabled():
                         torch.clear_autocast_cache()
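For context, the loop being edited implements AlphaFold's recycling: the model runs several forward passes over the same input and keeps gradients only for the final pass. Below is a minimal sketch of that pattern, assuming a generic `model` callable and a `prev` keyword for carrying the previous cycle's outputs (both hypothetical, not OpenFold's actual signature):

import torch

# Illustrative sketch of the recycling pattern seen in the diff above.
# `model` and its `prev` keyword are hypothetical stand-ins.
def recycled_forward(model, batch, is_grad_enabled=True):
    # The last dimension of each batch tensor indexes recycling iterations
    num_iters = batch["aatype"].shape[-1]
    outputs = None
    for cycle_no in range(num_iters):
        # Select the features for the current recycling cycle
        feats = {k: t[..., cycle_no] for k, t in batch.items()}
        is_final_iter = cycle_no == (num_iters - 1)
        # Only the final iteration needs gradients; earlier passes run
        # without them to save memory
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            if is_final_iter and torch.is_autocast_enabled():
                # Sidestep AMP bug (PyTorch issue #65766): casts cached
                # during the no-grad iterations must not be reused once
                # gradients are enabled
                torch.clear_autocast_cache()
            outputs = model(feats, prev=outputs)
    return outputs

Running the earlier cycles under no-grad is what makes recycling affordable: activations are only retained for the single pass that backpropagates.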