Commit 4bd43751 authored by Gustaf Ahdritz
Browse files

Add cache clearing to config

parent 263661a3
...@@ -23,9 +23,9 @@ ...@@ -23,9 +23,9 @@
"opt_level": "O2" "opt_level": "O2"
}, },
"zero_optimization": { "zero_optimization": {
"stage": 1, "stage": 2,
"cpu_offload": false, "cpu_offload": true,
"contiguous_gradients": false "contiguous_gradients": true
}, },
"activation_checkpointing": { "activation_checkpointing": {
"partition_activations": true, "partition_activations": true,
......
...@@ -226,7 +226,7 @@ config = mlc.ConfigDict( ...@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False, "use_small_bfd": False,
"data_loaders": { "data_loaders": {
"batch_size": 1, "batch_size": 1,
"num_workers": 4, "num_workers": 8,
}, },
}, },
}, },
...@@ -319,6 +319,7 @@ config = mlc.ConfigDict( ...@@ -319,6 +319,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": True,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
...@@ -339,6 +340,7 @@ config = mlc.ConfigDict( ...@@ -339,6 +340,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15, "msa_dropout": 0.15,
"pair_dropout": 0.25, "pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9, "inf": 1e9,
"eps": eps, # 1e-10, "eps": eps, # 1e-10,
}, },
......
...@@ -270,8 +270,9 @@ class EvoformerStack(nn.Module): ...@@ -270,8 +270,9 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt: int, blocks_per_ckpt: int,
inf: float, inf: float,
eps: float, eps: float,
clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
_clear_cache_btwn_extra_blocks: bool = True, _: bool = True,
**kwargs, **kwargs,
): ):
""" """
...@@ -309,8 +310,8 @@ class EvoformerStack(nn.Module): ...@@ -309,8 +310,8 @@ class EvoformerStack(nn.Module):
super(EvoformerStack, self).__init__() super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self._is_extra_msa_stack = _is_extra_msa_stack self._is_extra_msa_stack = _is_extra_msa_stack
self._clear_cache_btwn_extra_blocks = _clear_cache_btwn_extra_blocks
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
...@@ -373,8 +374,10 @@ class EvoformerStack(nn.Module): ...@@ -373,8 +374,10 @@ class EvoformerStack(nn.Module):
) )
for b in self.blocks for b in self.blocks
] ]
if(self._is_extra_msa_stack and self._clear_cache_btwn_extra_blocks):
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args): def block_with_cache_clear(block, *args):
print("hello!")
torch.cuda.empty_cache() torch.cuda.empty_cache()
return block(*args) return block(*args)
...@@ -418,6 +421,7 @@ class ExtraMSAStack(nn.Module): ...@@ -418,6 +421,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt: int, blocks_per_ckpt: int,
inf: float, inf: float,
eps: float, eps: float,
clear_cache_between_blocks: bool = False,
**kwargs, **kwargs,
): ):
super(ExtraMSAStack, self).__init__() super(ExtraMSAStack, self).__init__()
...@@ -440,6 +444,7 @@ class ExtraMSAStack(nn.Module): ...@@ -440,6 +444,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt=blocks_per_ckpt, blocks_per_ckpt=blocks_per_ckpt,
inf=inf, inf=inf,
eps=eps, eps=eps,
clear_cache_between_blocks=clear_cache_between_blocks,
_is_extra_msa_stack=True, _is_extra_msa_stack=True,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment