Commit 263661a3 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Clear cuda cache between extra MSA blocks to alleviate fragmentation

parent 0ddf27a9
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
"opt_level": "O2" "opt_level": "O2"
}, },
"zero_optimization": { "zero_optimization": {
"stage": 1 "stage": 1,
"cpu_offload": false,
"contiguous_gradients": false
}, },
"activation_checkpointing": { "activation_checkpointing": {
"partition_activations": true, "partition_activations": true,
......
...@@ -271,6 +271,7 @@ class EvoformerStack(nn.Module): ...@@ -271,6 +271,7 @@ class EvoformerStack(nn.Module):
inf: float, inf: float,
eps: float, eps: float,
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
_clear_cache_btwn_extra_blocks: bool = True,
**kwargs, **kwargs,
): ):
""" """
...@@ -309,6 +310,7 @@ class EvoformerStack(nn.Module): ...@@ -309,6 +310,7 @@ class EvoformerStack(nn.Module):
self.blocks_per_ckpt = blocks_per_ckpt self.blocks_per_ckpt = blocks_per_ckpt
self._is_extra_msa_stack = _is_extra_msa_stack self._is_extra_msa_stack = _is_extra_msa_stack
self._clear_cache_btwn_extra_blocks = _clear_cache_btwn_extra_blocks
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
...@@ -361,17 +363,25 @@ class EvoformerStack(nn.Module): ...@@ -361,17 +363,25 @@ class EvoformerStack(nn.Module):
s: s:
[*, N_res, C_s] single embedding (or None if extra MSA stack) [*, N_res, C_s] single embedding (or None if extra MSA stack)
""" """
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self._is_extra_msa_stack and self._clear_cache_btwn_extra_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks( m, z = checkpoint_blocks(
blocks=[ blocks,
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(m, z), args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment