OpenDAS / OpenFold · Commit 4c1cf6e2
"examples/ScanNet/data.py" did not exist on "a3a079efc2ef1dcd83a4ba9dfa395e52506814ba"
Add chunk size tuning

Authored Jun 12, 2022 by Gustaf Ahdritz · parent 8036a213
Showing 2 changed files with 112 additions and 7 deletions (+112, -7)
openfold/model/evoformer.py    +39 -7
openfold/utils/tensor_utils.py +73 -0
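In brief: the commit threads a new tune_chunk_size flag (default True) through EvoformerStack and ExtraMSAStack and adds a ChunkSizeTuner utility to openfold/utils/tensor_utils.py. At inference time, the tuner probes a representative block with progressively larger chunk sizes and picks the largest one that runs without error, using the caller-supplied chunk_size as a floor.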
openfold/model/evoformer.py
```diff
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationIncoming,
 )
 from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
-from openfold.utils.tensor_utils import add, chunk_layer
+from openfold.utils.tensor_utils import add, chunk_layer, ChunkSizeTuner
 
 
 class MSATransition(nn.Module):
```
```diff
@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
+        tune_chunk_size: bool = True,
         **kwargs,
     ):
         """
```
```diff
@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module):
             clear_cache_between_blocks:
                 Whether to clear CUDA's GPU memory cache between blocks of the
                 stack. Slows down each block but can reduce fragmentation
+            tune_chunk_size:
+                Whether to dynamically tune the module's chunk size
         """
         super(EvoformerStack, self).__init__()
```
```diff
@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module):
         self.linear = Linear(c_m, c_s)
 
+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if(tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
```
```diff
@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module):
                 [*, N_seq, N_res] MSA mask
             pair_mask:
                 [*, N_res, N_res] pair mask
-            chunk_size: Inference-time subbatch size
+            chunk_size:
+                Inference-time subbatch size. Acts as a minimum if
+                self.tune_chunk_size is True
             use_lma: Whether to use low-memory attention during inference
         Returns:
             m:
```
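Note the semantic shift documented here: with tuning enabled, the chunk_size argument to forward is a lower bound rather than a fixed value. The effective chunk size is whatever the tuner settles on, up to ChunkSizeTuner's max_chunk_size (256 by default, per the tensor_utils.py change below).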
```diff
@@ -590,7 +600,7 @@ class EvoformerStack(nn.Module):
                 [*, N_res, N_res, C_z] pair embedding
             s:
                 [*, N_res, C_s] single embedding (or None if extra MSA stack)
-        """
+        """
         blocks = [
             partial(
                 b,
```
```diff
@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module):
         ]
 
         if(self.clear_cache_between_blocks):
-            def block_with_cache_clear(block, *args):
+            def block_with_cache_clear(block, *args, **kwargs):
                 torch.cuda.empty_cache()
-                return block(*args)
+                return block(*args, **kwargs)
 
             blocks = [partial(block_with_cache_clear, b) for b in blocks]
 
+        if(chunk_size is not None and self.chunk_size_tuner is not None):
+            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                args=(m, z),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+
         blocks_per_ckpt = self.blocks_per_ckpt
         if(not torch.is_grad_enabled()):
             blocks_per_ckpt = None
```
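To make the wiring above concrete, here is a minimal, hypothetical sketch (not part of the commit) of the contract the forward pass relies on: every block is a callable over (m, z) that also accepts a chunk_size keyword, so the first block can serve as representative_fn and every block can then be re-bound with the tuned value. toy_block and tuned_chunk_size are illustrative stand-ins.

```python
from functools import partial

# Hypothetical stand-in for an Evoformer block partial; real blocks run
# MSA/pair attention over subbatches of size chunk_size.
def toy_block(m, z, chunk_size=None):
    return m, z

blocks = [toy_block] * 4

# The tuner probes blocks[0] with candidate chunk sizes (see ChunkSizeTuner
# below); forward then fixes the winning value on every block in the stack.
tuned_chunk_size = 128  # e.g., the value returned by tune_chunk_size(...)
blocks = [partial(b, chunk_size=tuned_chunk_size) for b in blocks]

m, z = 0, 0  # placeholders; the real inputs are torch.Tensors
for b in blocks:
    m, z = b(m, z)
```

This also explains why the cache-clearing wrapper gains **kwargs in this commit: without it, the chunk_size keyword bound by the second list comprehension would never reach the underlying block.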
```diff
@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module):
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
         chunk_msa_attn: bool = False,
+        tune_chunk_size: bool = True,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()
```
```diff
@@ -673,6 +692,11 @@ class ExtraMSAStack(nn.Module):
                 ckpt=ckpt if chunk_msa_attn else False,
             )
             self.blocks.append(block)
 
+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if(tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
     def forward(self,
         m: torch.Tensor,
```
```diff
@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module):
             ) for b in self.blocks
         ]
 
-        def clear_cache(b, *args):
+        def clear_cache(b, *args, **kwargs):
             torch.cuda.empty_cache()
-            return b(*args)
+            return b(*args, **kwargs)
 
         if(self.clear_cache_between_blocks):
             blocks = [partial(clear_cache, b) for b in blocks]
 
+        if(chunk_size is not None and self.chunk_size_tuner is not None):
+            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                args=(m, z),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+
         for b in blocks:
             if(self.ckpt and torch.is_grad_enabled()):
                 m, z = checkpoint_fn(b, *(m, z))
```
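The ExtraMSAStack changes mirror the EvoformerStack ones exactly: the cache-clearing wrapper learns to forward **kwargs so the later chunk_size re-binding survives it, and the tuner runs once on blocks[0] before the main loop.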
openfold/utils/tensor_utils.py
```diff
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 from functools import partial
+import logging
+import math
 
 import torch
 import torch.nn as nn
 from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
```
```diff
@@ -417,3 +419,74 @@ def chunk_layer(
     out = tensor_tree_map(reshape, out)
 
     return out
+
+
+class ChunkSizeTuner:
+    def __init__(self,
+        # Heuristically, runtimes for most of the modules in the network
+        # plateau earlier than this on all GPUs I've run the model on.
+        max_chunk_size=256,
+    ):
+        self.max_chunk_size = max_chunk_size
+        self.cached_chunk_size = None
+        self.cached_arg_data = None
+
+    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
+        logging.info("Tuning chunk size...")
+
+        if(min_chunk_size >= self.max_chunk_size):
+            return min_chunk_size
+
+        candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
+        candidates = [c for c in candidates if c > min_chunk_size]
+        candidates = [min_chunk_size] + candidates
+
+        def test_chunk_size(chunk_size):
+            try:
+                with torch.no_grad():
+                    fn(*args, chunk_size=chunk_size)
+                return True
+            except RuntimeError:
+                return False
+
+        min_viable_chunk_size_index = 0
+        i = len(candidates) - 1
+        while(i > min_viable_chunk_size_index):
+            viable = test_chunk_size(candidates[i])
+            if(not viable):
+                i = (min_viable_chunk_size_index + i) // 2
+            else:
+                min_viable_chunk_size_index = i
+                i = (i + len(candidates) - 1) // 2
+
+        return candidates[min_viable_chunk_size_index]
+
+    def tune_chunk_size(self,
+        representative_fn: Callable,
+        args: Tuple[Any],
+        min_chunk_size: int,
+    ) -> int:
+        consistent = True
+        arg_data = [
+            arg if type(arg) != torch.Tensor else arg.shape for arg in args
+        ]
+        if(self.cached_arg_data is not None):
+            # If args have changed shape/value, we need to re-tune
+            assert(len(self.cached_arg_data) == len(args))
+            arg_data_iter = zip(self.cached_arg_data, arg_data)
+            for cached_arg_data, arg_data in arg_data_iter:
+                assert(type(cached_arg_data) == type(arg_data))
+                consistent = cached_arg_data == arg_data
+        else:
+            # Otherwise, we can reuse the precomputed value
+            consistent = False
+
+        if(not consistent):
+            self.cached_chunk_size = self._determine_favorable_chunk_size(
+                representative_fn,
+                args,
+                min_chunk_size,
+            )
+            self.cached_arg_data = arg_data
+
+        return self.cached_chunk_size
```
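To illustrate the new class end to end: _determine_favorable_chunk_size enumerates power-of-two candidates up to max_chunk_size, binary-searches for the largest candidate whose trial run does not raise a RuntimeError (in practice, a CUDA OOM), and tune_chunk_size caches the result keyed on the shapes of the tensor arguments, so later calls with same-shaped inputs skip the probing entirely. The following hedged, self-contained sketch exercises that behavior; toy_fn is hypothetical and simulates OOM with a plain RuntimeError above a fixed threshold.

```python
import torch

from openfold.utils.tensor_utils import ChunkSizeTuner

# Hypothetical stand-in for a model block: pretends to run out of memory
# for any chunk size above 64.
def toy_fn(m, z, chunk_size=None):
    if chunk_size is not None and chunk_size > 64:
        raise RuntimeError("simulated CUDA OOM")
    return m, z

tuner = ChunkSizeTuner(max_chunk_size=256)
m = torch.zeros(8, 16)
z = torch.zeros(16, 16)

best = tuner.tune_chunk_size(
    representative_fn=toy_fn,
    args=(m, z),
    min_chunk_size=4,
)
print(best)  # 64: the largest candidate in [4, 8, ..., 256] that succeeds

# A second call with same-shaped inputs returns the cached value without
# re-running the probes.
assert tuner.tune_chunk_size(toy_fn, (m, z), 4) == 64
```

One design note: because probing actually executes the representative block, tuning trades a handful of throwaway forward passes on the first call for larger, better-utilized chunks on every subsequent one, which is why it is gated on inference (the trial runs happen under torch.no_grad()).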