"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "505468d19aa77ab5a61d92d9545b0071dd622cdf"
Commit 4bd43751 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add cache clearing to config

parent 263661a3
......@@ -23,9 +23,9 @@
"opt_level": "O2"
},
"zero_optimization": {
"stage": 1,
"cpu_offload": false,
"contiguous_gradients": false
"stage": 2,
"cpu_offload": true,
"contiguous_gradients": true
},
"activation_checkpointing": {
"partition_activations": true,
......
......@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 4,
"num_workers": 8,
},
},
},
......@@ -319,6 +319,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
},
......@@ -339,6 +340,7 @@ config = mlc.ConfigDict(
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
},
......
......@@ -270,8 +270,9 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt: int,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False,
_clear_cache_btwn_extra_blocks: bool = True,
_: bool = True,
**kwargs,
):
"""
......@@ -309,8 +310,8 @@ class EvoformerStack(nn.Module):
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self._is_extra_msa_stack = _is_extra_msa_stack
self._clear_cache_btwn_extra_blocks = _clear_cache_btwn_extra_blocks
self.blocks = nn.ModuleList()
......@@ -373,8 +374,10 @@ class EvoformerStack(nn.Module):
)
for b in self.blocks
]
if(self._is_extra_msa_stack and self._clear_cache_btwn_extra_blocks):
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
print("hello!")
torch.cuda.empty_cache()
return block(*args)
......@@ -418,6 +421,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt: int,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
......@@ -440,6 +444,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt=blocks_per_ckpt,
inf=inf,
eps=eps,
clear_cache_between_blocks=clear_cache_between_blocks,
_is_extra_msa_stack=True,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment