Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4c1cf6e2
"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "ae2c03a94db9a4948a17a4defaf679c519e13d4e"
Commit
4c1cf6e2
authored
Jun 12, 2022
by
Gustaf Ahdritz
Browse files
Add chunk size tuning
parent
8036a213
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
7 deletions
+112
-7
openfold/model/evoformer.py
openfold/model/evoformer.py
+39
-7
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+73
-0
No files found.
openfold/model/evoformer.py
View file @
4c1cf6e2
...
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
,
ChunkSizeTuner
class
MSATransition
(
nn
.
Module
):
class
MSATransition
(
nn
.
Module
):
...
@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module):
...
@@ -498,6 +498,7 @@ class EvoformerStack(nn.Module):
inf
:
float
,
inf
:
float
,
eps
:
float
,
eps
:
float
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
tune_chunk_size
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module):
...
@@ -534,6 +535,8 @@ class EvoformerStack(nn.Module):
clear_cache_between_blocks:
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
"""
"""
super
(
EvoformerStack
,
self
).
__init__
()
super
(
EvoformerStack
,
self
).
__init__
()
...
@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module):
...
@@ -562,6 +565,11 @@ class EvoformerStack(nn.Module):
self
.
linear
=
Linear
(
c_m
,
c_s
)
self
.
linear
=
Linear
(
c_m
,
c_s
)
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
...
@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module):
...
@@ -581,7 +589,9 @@ class EvoformerStack(nn.Module):
[*, N_seq, N_res] MSA mask
[*, N_seq, N_res] MSA mask
pair_mask:
pair_mask:
[*, N_res, N_res] pair mask
[*, N_res, N_res] pair mask
chunk_size: Inference-time subbatch size
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_lma: Whether to use low-memory attention during inference
Returns:
Returns:
m:
m:
...
@@ -590,7 +600,7 @@ class EvoformerStack(nn.Module):
...
@@ -590,7 +600,7 @@ class EvoformerStack(nn.Module):
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
s:
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
"""
blocks
=
[
blocks
=
[
partial
(
partial
(
b
,
b
,
...
@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module):
...
@@ -604,12 +614,20 @@ class EvoformerStack(nn.Module):
]
]
if
(
self
.
clear_cache_between_blocks
):
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
):
def
block_with_cache_clear
(
block
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
)
return
block
(
*
args
,
**
kwargs
)
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
m
,
z
),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
chunk_size
)
for
b
in
blocks
]
blocks_per_ckpt
=
self
.
blocks_per_ckpt
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
blocks_per_ckpt
=
None
...
@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -647,6 +665,7 @@ class ExtraMSAStack(nn.Module):
ckpt
:
bool
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
chunk_msa_attn
:
bool
=
False
,
chunk_msa_attn
:
bool
=
False
,
tune_chunk_size
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
):
):
super
(
ExtraMSAStack
,
self
).
__init__
()
super
(
ExtraMSAStack
,
self
).
__init__
()
...
@@ -673,6 +692,11 @@ class ExtraMSAStack(nn.Module):
...
@@ -673,6 +692,11 @@ class ExtraMSAStack(nn.Module):
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
self
.
tune_chunk_size
=
tune_chunk_size
self
.
chunk_size_tuner
=
None
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
...
@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module):
...
@@ -712,13 +736,21 @@ class ExtraMSAStack(nn.Module):
)
for
b
in
self
.
blocks
)
for
b
in
self
.
blocks
]
]
def
clear_cache
(
b
,
*
args
):
def
clear_cache
(
b
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
b
(
*
args
)
return
b
(
*
args
,
**
kwargs
)
if
(
self
.
clear_cache_between_blocks
):
if
(
self
.
clear_cache_between_blocks
):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
m
,
z
),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
chunk_size
)
for
b
in
blocks
]
for
b
in
blocks
:
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
*
(
m
,
z
))
m
,
z
=
checkpoint_fn
(
b
,
*
(
m
,
z
))
...
...
openfold/utils/tensor_utils.py
View file @
4c1cf6e2
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
# limitations under the License.
# limitations under the License.
from
functools
import
partial
from
functools
import
partial
import
logging
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
...
@@ -417,3 +419,74 @@ def chunk_layer(
...
@@ -417,3 +419,74 @@ def chunk_layer(
out
=
tensor_tree_map
(
reshape
,
out
)
out
=
tensor_tree_map
(
reshape
,
out
)
return
out
return
out
class ChunkSizeTuner:
    """Empirically tunes a module's inference-time chunk size.

    Probes a representative forward call with successively larger
    power-of-two chunk sizes and keeps the largest one that runs without
    raising ``RuntimeError`` (most commonly a CUDA out-of-memory error —
    TODO confirm no other RuntimeErrors are expected here). The result is
    cached and reused until the argument fingerprint (tensor shapes /
    non-tensor values) changes.
    """

    def __init__(
        self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs I've run the model on.
        max_chunk_size: int = 256,
    ):
        # Upper bound on the chunk sizes that will be probed
        self.max_chunk_size = max_chunk_size
        # Most recently tuned chunk size (None until the first tuning run)
        self.cached_chunk_size = None
        # Argument fingerprint of the run that produced cached_chunk_size
        self.cached_arg_data = None

    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
        """Bisect for the largest viable chunk size.

        Args:
            fn: Callable accepting ``*args`` plus a ``chunk_size`` keyword
            args: Positional arguments forwarded verbatim to ``fn``
            min_chunk_size: Smallest chunk size the caller will accept
        Returns:
            The largest probed chunk size in
            ``[min_chunk_size, max_chunk_size]`` for which ``fn`` completed
            without a RuntimeError.
        """
        logging.info("Tuning chunk size...")

        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        # Power-of-two candidates up to max_chunk_size, with
        # min_chunk_size itself as the guaranteed-acceptable floor
        candidates = [
            2 ** l
            for l in range(int(math.log(self.max_chunk_size, 2)) + 1)
        ]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates

        def test_chunk_size(chunk_size):
            # A RuntimeError is treated as "this chunk size is too large"
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        # Bisection relies on viability being monotonic: if a chunk size
        # fits in memory, every smaller one does too
        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        return candidates[min_viable_chunk_size_index]

    def tune_chunk_size(
        self,
        representative_fn: Callable,
        args: Tuple[Any, ...],
        min_chunk_size: int,
    ) -> int:
        """Return a tuned chunk size, re-tuning only when the args change.

        Args:
            representative_fn:
                A call representative of the workload being chunked; must
                accept ``*args`` and a ``chunk_size`` keyword
            args:
                Arguments for ``representative_fn``. Tensors are
                fingerprinted by shape, everything else by value
            min_chunk_size:
                Lower bound on the returned chunk size
        Returns:
            The cached or freshly tuned chunk size
        """
        consistent = True
        # isinstance (rather than an exact type() check) so that tensor
        # subclasses are also fingerprinted by shape, not compared
        # element-wise by ==
        arg_data = [
            arg.shape if isinstance(arg, torch.Tensor) else arg
            for arg in args
        ]
        if self.cached_arg_data is not None:
            # If the args have changed shape/value, we need to re-tune
            assert len(self.cached_arg_data) == len(arg_data)
            for cached, current in zip(self.cached_arg_data, arg_data):
                assert type(cached) == type(current)
                # BUGFIX: accumulate with "and" — previously only the last
                # pair's comparison survived, so a mismatch in any earlier
                # argument was silently ignored. (Also renamed the loop
                # variables: the old loop shadowed arg_data itself.)
                consistent = consistent and (cached == current)
        else:
            # Nothing cached yet: force an initial tuning run
            consistent = False

        if not consistent:
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            # BUGFIX: the old loop rebound arg_data to its last element,
            # so only a fragment of the fingerprint was cached here
            self.cached_arg_data = arg_data

        return self.cached_chunk_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment