Commit 85c0a9a9 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Disable activation checkpointing during inference

parent 0333b7a3
...@@ -350,7 +350,7 @@ class EvoformerStack(nn.Module): ...@@ -350,7 +350,7 @@ class EvoformerStack(nn.Module):
) for b in self.blocks ) for b in self.blocks
], ],
args=(m, z), args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
s = None s = None
......
...@@ -265,7 +265,7 @@ class TemplatePairStack(nn.Module): ...@@ -265,7 +265,7 @@ class TemplatePairStack(nn.Module):
) for b in self.blocks ) for b in self.blocks
], ],
args=(t), args=(t),
blocks_per_ckpt=self.blocks_per_ckpt, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
t = self.layer_norm(t) t = self.layer_norm(t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment