Commit 71057ac7 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove redundant parameter, add documentation for cache clearing

parent 4bd43751
...@@ -272,7 +272,6 @@ class EvoformerStack(nn.Module): ...@@ -272,7 +272,6 @@ class EvoformerStack(nn.Module):
eps: float, eps: float,
clear_cache_between_blocks: bool = False, clear_cache_between_blocks: bool = False,
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
_: bool = True,
**kwargs, **kwargs,
): ):
""" """
...@@ -306,6 +305,9 @@ class EvoformerStack(nn.Module): ...@@ -306,6 +305,9 @@ class EvoformerStack(nn.Module):
Dropout used for pair activations Dropout used for pair activations
blocks_per_ckpt: blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
""" """
super(EvoformerStack, self).__init__() super(EvoformerStack, self).__init__()
...@@ -377,7 +379,6 @@ class EvoformerStack(nn.Module): ...@@ -377,7 +379,6 @@ class EvoformerStack(nn.Module):
if(self.clear_cache_between_blocks): if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args): def block_with_cache_clear(block, *args):
print("hello!")
torch.cuda.empty_cache() torch.cuda.empty_cache()
return block(*args) return block(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment