Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4bd43751
Commit
4bd43751
authored
Nov 19, 2021
by
Gustaf Ahdritz
Browse files
Add cache clearing to config
parent
263661a3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
7 deletions
+14
-7
deepspeed_config.json
deepspeed_config.json
+3
-3
openfold/config.py
openfold/config.py
+3
-1
openfold/model/evoformer.py
openfold/model/evoformer.py
+8
-3
No files found.
deepspeed_config.json
View file @
4bd43751
...
@@ -23,9 +23,9 @@
...
@@ -23,9 +23,9 @@
"opt_level"
:
"O2"
"opt_level"
:
"O2"
},
},
"zero_optimization"
:
{
"zero_optimization"
:
{
"stage"
:
1
,
"stage"
:
2
,
"cpu_offload"
:
fals
e
,
"cpu_offload"
:
tru
e
,
"contiguous_gradients"
:
fals
e
"contiguous_gradients"
:
tru
e
},
},
"activation_checkpointing"
:
{
"activation_checkpointing"
:
{
"partition_activations"
:
true
,
"partition_activations"
:
true
,
...
...
openfold/config.py
View file @
4bd43751
...
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
...
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
"use_small_bfd"
:
False
,
"use_small_bfd"
:
False
,
"data_loaders"
:
{
"data_loaders"
:
{
"batch_size"
:
1
,
"batch_size"
:
1
,
"num_workers"
:
4
,
"num_workers"
:
8
,
},
},
},
},
},
},
...
@@ -319,6 +319,7 @@ config = mlc.ConfigDict(
...
@@ -319,6 +319,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
True
,
"inf"
:
1e9
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"eps"
:
eps
,
# 1e-10,
},
},
...
@@ -339,6 +340,7 @@ config = mlc.ConfigDict(
...
@@ -339,6 +340,7 @@ config = mlc.ConfigDict(
"msa_dropout"
:
0.15
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"pair_dropout"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"inf"
:
1e9
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"eps"
:
eps
,
# 1e-10,
},
},
...
...
openfold/model/evoformer.py
View file @
4bd43751
...
@@ -270,8 +270,9 @@ class EvoformerStack(nn.Module):
...
@@ -270,8 +270,9 @@ class EvoformerStack(nn.Module):
blocks_per_ckpt
:
int
,
blocks_per_ckpt
:
int
,
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
_is_extra_msa_stack
:
bool
=
False
,
_is_extra_msa_stack
:
bool
=
False
,
_
clear_cache_btwn_extra_blocks
:
bool
=
True
,
_
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -309,8 +310,8 @@ class EvoformerStack(nn.Module):
...
@@ -309,8 +310,8 @@ class EvoformerStack(nn.Module):
super
(
EvoformerStack
,
self
).
__init__
()
super
(
EvoformerStack
,
self
).
__init__
()
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
_is_extra_msa_stack
=
_is_extra_msa_stack
self
.
_is_extra_msa_stack
=
_is_extra_msa_stack
self
.
_clear_cache_btwn_extra_blocks
=
_clear_cache_btwn_extra_blocks
self
.
blocks
=
nn
.
ModuleList
()
self
.
blocks
=
nn
.
ModuleList
()
...
@@ -373,8 +374,10 @@ class EvoformerStack(nn.Module):
...
@@ -373,8 +374,10 @@ class EvoformerStack(nn.Module):
)
)
for
b
in
self
.
blocks
for
b
in
self
.
blocks
]
]
if
(
self
.
_is_extra_msa_stack
and
self
.
_clear_cache_btwn_extra_blocks
):
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
):
def
block_with_cache_clear
(
block
,
*
args
):
print
(
"hello!"
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
)
return
block
(
*
args
)
...
@@ -418,6 +421,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -418,6 +421,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt
:
int
,
blocks_per_ckpt
:
int
,
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
super
(
ExtraMSAStack
,
self
).
__init__
()
super
(
ExtraMSAStack
,
self
).
__init__
()
...
@@ -440,6 +444,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -440,6 +444,7 @@ class ExtraMSAStack(nn.Module):
blocks_per_ckpt
=
blocks_per_ckpt
,
blocks_per_ckpt
=
blocks_per_ckpt
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
clear_cache_between_blocks
=
clear_cache_between_blocks
,
_is_extra_msa_stack
=
True
,
_is_extra_msa_stack
=
True
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment