Commit 8036a213 authored by Gustaf Ahdritz

Remove redundant checkpoint functions

parent 5ada323e
@@ -375,24 +375,6 @@ class AlphaFold(nn.Module):
         return outputs, m_1_prev, z_prev, x_prev
 
-    def _disable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = None
-        self.evoformer.blocks_per_ckpt = None
-
-        for b in self.extra_msa_stack.blocks:
-            b.ckpt = False
-
-    def _enable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = (
-            self.config.template.template_pair_stack.blocks_per_ckpt
-        )
-        self.evoformer.blocks_per_ckpt = (
-            self.config.evoformer_stack.blocks_per_ckpt
-        )
-
-        for b in self.extra_msa_stack.blocks:
-            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
-
     def forward(self, batch):
         """
         Args:
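The deleted helpers toggled activation checkpointing by mutating each stack's `blocks_per_ckpt` attribute (where `None` disables it) and each extra-MSA block's `ckpt` flag. The sketch below illustrates that gating pattern with a toy `BlockStack` that is not OpenFold's actual module; the class, its dimensions, and the `use_reentrant=False` flag are illustrative assumptions, and the real stacks chunk blocks into groups of `blocks_per_ckpt` rather than checkpointing each block individually.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class BlockStack(nn.Module):
    """Toy stack whose checkpointing is gated by `blocks_per_ckpt` (hypothetical)."""

    def __init__(self, dim, num_blocks, blocks_per_ckpt=1):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
            for _ in range(num_blocks)
        )
        self.blocks_per_ckpt = blocks_per_ckpt  # None turns checkpointing off

    def forward(self, x):
        for block in self.blocks:
            if self.blocks_per_ckpt is None or not torch.is_grad_enabled():
                x = block(x)  # plain forward; nothing to recompute later
            else:
                # Store no intermediate activations; recompute them on backward
                x = checkpoint(block, x, use_reentrant=False)
        return x


stack = BlockStack(dim=8, num_blocks=4)
x = torch.randn(2, 8, requires_grad=True)

stack.blocks_per_ckpt = None  # what _disable_activation_checkpointing did
stack.blocks_per_ckpt = 1     # what _enable_activation_checkpointing restored
stack(x).sum().backward()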
@@ -448,9 +430,7 @@ class AlphaFold(nn.Module):
         m_1_prev, z_prev, x_prev = None, None, None
         prevs = [m_1_prev, z_prev, x_prev]
 
-        # Disable activation checkpointing for the first few recycling iters
         is_grad_enabled = torch.is_grad_enabled()
-        self._disable_activation_checkpointing()
 
         # Main recycling loop
         num_iters = batch["aatype"].shape[-1]
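With the helpers gone, the loop appears to rely on gradient mode alone: presumably the functions were redundant because every recycling iteration except the last runs with gradients disabled, and checkpointing is effectively a no-op when no backward pass will ever run, so there is nothing to toggle per iteration. A hedged sketch of that pattern, simplified from the loop above; `run_recycling` and `iteration_fn` are stand-in names, not OpenFold's real single-iteration forward.

import torch


def run_recycling(iteration_fn, feats, num_iters):
    """Run `iteration_fn` `num_iters` times, tracking gradients only on the last pass."""
    prevs = None
    outputs = None
    is_grad_enabled = torch.is_grad_enabled()
    for cycle_no in range(num_iters):
        is_final_iter = cycle_no == (num_iters - 1)
        # Early iterations run without grad, so any activation checkpointing
        # inside iteration_fn costs nothing and needs no manual disabling.
        with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
            outputs, prevs = iteration_fn(feats, prevs)
    return outputs


# Tiny demo with a stand-in single-iteration function
linear = torch.nn.Linear(4, 4)
out = run_recycling(lambda f, p: (linear(f), p), torch.randn(2, 4), num_iters=3)
out.sum().backward()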