Commit 67e34d92 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Avoid calling the checkpointing function with grad disabled

parent c290ff87
@@ -571,10 +571,14 @@ class EvoformerStack(nn.Module):
     blocks = [partial(block_with_cache_clear, b) for b in blocks]

+    blocks_per_ckpt = self.blocks_per_ckpt
+    if(not torch.is_grad_enabled()):
+        blocks_per_ckpt = None
+
     m, z = checkpoint_blocks(
         blocks,
         args=(m, z),
-        blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
+        blocks_per_ckpt=blocks_per_ckpt,
     )

     s = self.linear(m[..., 0, :, :])
...
Markdown is supported
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment