OpenDAS / OpenFold · Commits

Commit d6b36a80
Authored Jun 15, 2022 by Gustaf Ahdritz

Tweak chunk tuning a little

Parent: 29f2ffe0
Showing 2 changed files with 81 additions and 28 deletions
openfold/model/evoformer.py  +53 -17
openfold/model/template.py   +28 -11
openfold/model/evoformer.py (view file @ d6b36a80)
...
...
@@ -185,12 +185,16 @@ class EvoformerBlockCore(nn.Module):
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
         # Need to dodge activation checkpoints
         inplace_safe = not (self.training or torch.is_grad_enabled())
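Note on the hunk above: the new _attn_chunk_size argument lets the attention sub-modules be chunked more finely than the transitions, and it simply falls back to the public chunk_size when unset. The standalone sketch below is illustrative only (OpenFold's real chunking helper is chunk_layer in openfold.utils.tensor_utils); it shows the general idea of applying a function over slices so peak memory scales with the chunk, together with the same fallback.

import torch
from typing import Optional

# Illustrative chunked application: process slices of the leading dimension so
# peak memory scales with the chunk size rather than the full tensor.
def apply_chunked(fn, x: torch.Tensor, chunk_size: Optional[int] = None) -> torch.Tensor:
    if chunk_size is None:
        return fn(x)
    pieces = [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    return torch.cat(pieces, dim=0)

# The private argument defaults to the public one, mirroring the hunk above.
def core_sketch(x, chunk_size=None, _attn_chunk_size=None):
    if _attn_chunk_size is None:
        _attn_chunk_size = chunk_size
    x = apply_chunked(lambda t: torch.softmax(t, dim=-1), x, _attn_chunk_size)  # attention-like step
    x = apply_chunked(torch.relu, x, chunk_size)                                # transition-like step
    return x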
...
...
@@ -240,7 +244,7 @@ class EvoformerBlockCore(nn.Module):
                 self.tri_att_start(
                     z,
                     mask=pair_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma
                 )
             ),
...
...
@@ -251,7 +255,7 @@ class EvoformerBlockCore(nn.Module):
                 self.tri_att_end(
                     z,
                     mask=pair_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                 )
             ),
...
...
@@ -324,21 +328,33 @@ class EvoformerBlock(nn.Module):
         chunk_size: Optional[int] = None,
         use_lma: bool = False,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(
-                m,
-                z=z,
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                use_lma=use_lma,
-            )
-        )
-        m = m + self.msa_att_col(
-            m,
-            mask=msa_mask,
-            chunk_size=chunk_size,
-            use_lma=use_lma,
-        )
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
+        m = add(m,
+            self.msa_dropout_layer(
+                self.msa_att_row(
+                    m,
+                    z=z,
+                    mask=msa_mask,
+                    chunk_size=_attn_chunk_size,
+                    use_lma=use_lma,
+                )
+            ),
+            inplace=inplace_safe,
+        )
+        m = add(m,
+            self.msa_att_col(
+                m,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+            ),
+            inplace=inplace_safe,
+        )
         m, z = self.core(
             m,
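Note on the hunk above: EvoformerBlock.forward now routes its residual updates through the add(..., inplace=inplace_safe) helper and permits in-place addition only when neither training nor autograd is active, since an in-place update would clobber activations that a backward pass (including recomputation under activation checkpointing) still needs. A minimal sketch of the pattern, with a hypothetical add_sketch standing in for the real add from openfold.utils.tensor_utils:

import torch

# Hypothetical stand-in for openfold.utils.tensor_utils.add: the in-place branch
# saves an allocation but is only legal when autograd will never need the
# original activation.
def add_sketch(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    if inplace:
        m1 += m2
    else:
        m1 = m1 + m2
    return m1

x = torch.ones(4)
with torch.no_grad():
    # Mirrors inplace_safe = not (self.training or torch.is_grad_enabled())
    inplace_safe = not torch.is_grad_enabled()  # True inside no_grad()
    x = add_sketch(x, torch.ones(4), inplace=inplace_safe)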
...
...
@@ -348,6 +364,7 @@ class EvoformerBlock(nn.Module):
             chunk_size=chunk_size,
             use_lma=use_lma,
             _mask_trans=_mask_trans,
+            _attn_chunk_size=_attn_chunk_size,
         )

         return m, z
...
...
@@ -421,7 +438,11 @@ class ExtraMSABlock(nn.Module):
         use_lma: bool = False,
         _chunk_logits: Optional[int] = 1024,
         _mask_trans: bool = True,
+        _attn_chunk_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if(_attn_chunk_size is None):
+            _attn_chunk_size = chunk_size
+
         # If function calls could speak...
         m = add(m,
             self.msa_dropout_layer(
...
...
@@ -429,7 +450,7 @@ class ExtraMSABlock(nn.Module):
                     m.clone() if torch.is_grad_enabled() else m,
                     z=z.clone() if torch.is_grad_enabled() else z,
                     mask=msa_mask,
-                    chunk_size=chunk_size,
+                    chunk_size=_attn_chunk_size,
                     use_lma=use_lma,
                     use_memory_efficient_kernel=not _chunk_logits and not use_lma,
                     _chunk_logits=
...
...
@@ -459,6 +480,7 @@ class ExtraMSABlock(nn.Module):
             chunk_size=chunk_size,
             use_lma=use_lma,
             _mask_trans=_mask_trans,
+            _attn_chunk_size=_attn_chunk_size
         )

         return m, z
...
...
@@ -621,12 +643,19 @@ class EvoformerStack(nn.Module):
             blocks = [partial(block_with_cache_clear, b) for b in blocks]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(m, z),
                 min_chunk_size=chunk_size,
             )
-            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                ) for b in blocks
+            ]

         blocks_per_ckpt = self.blocks_per_ckpt
         if(not torch.is_grad_enabled()):
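Note on the hunk above: the tuner's result is now kept in a separate tuned_chunk_size, the caller's chunk_size acts as a floor, and the attention chunk is capped at max(chunk_size, tuned_chunk_size // 2) as a workaround for occasional large-tensor allocation failures; ExtraMSAStack below receives the identical change. The toy sketch that follows shows the probe-then-rebind pattern with a hypothetical tuner in place of OpenFold's ChunkSizeTuner, assuming each block accepts chunk_size and _attn_chunk_size keyword arguments, as the blocks in this file do.

from functools import partial
import torch

# Hypothetical tuner: keep doubling the chunk size until the representative
# block fails (treated here as out-of-memory), then return the last size that worked.
def tune_chunk_size_sketch(representative_fn, args, min_chunk_size, max_chunk_size=1024):
    best = min_chunk_size
    candidate = min_chunk_size
    while candidate <= max_chunk_size:
        try:
            with torch.no_grad():
                representative_fn(*args, chunk_size=candidate)
            best = candidate
            candidate *= 2
        except RuntimeError:
            break
    return best

def bind_chunk_sizes(blocks, m, z, min_chunk_size):
    tuned_chunk_size = tune_chunk_size_sketch(blocks[0], (m, z), min_chunk_size)
    # Attention is chunked more finely than the rest of the block, but never
    # below the caller-supplied floor, as in the hunk above.
    return [
        partial(
            b,
            chunk_size=tuned_chunk_size,
            _attn_chunk_size=max(min_chunk_size, tuned_chunk_size // 2),
        )
        for b in blocks
    ]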
...
...
@@ -744,12 +773,19 @@ class ExtraMSAStack(nn.Module):
             blocks = [partial(clear_cache, b) for b in blocks]

         if(chunk_size is not None and self.chunk_size_tuner is not None):
-            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+            tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
                 args=(m, z),
                 min_chunk_size=chunk_size,
             )
-            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+            blocks = [
+                partial(b,
+                    chunk_size=tuned_chunk_size,
+                    # A temporary measure to address torch's occasional
+                    # inability to allocate large tensors
+                    _attn_chunk_size=max(chunk_size, tuned_chunk_size // 2),
+                ) for b in blocks
+            ]

         for b in blocks:
             if(self.ckpt and torch.is_grad_enabled()):
...
...
openfold/model/template.py (view file @ d6b36a80)
...
...
@@ -41,6 +41,7 @@ from openfold.utils.feats import (
 from openfold.utils.tensor_utils import (
     add,
     chunk_layer,
+    ChunkSizeTuner,
     permute_final_dims,
     flatten_final_dims,
     tensor_tree_map,
...
...
@@ -293,6 +294,7 @@ class TemplatePairStack(nn.Module):
         pair_transition_n,
         dropout_rate,
         blocks_per_ckpt,
+        tune_chunk_size: bool = False,
         inf=1e9,
         **kwargs,
     ):
...
...
@@ -333,6 +335,11 @@ class TemplatePairStack(nn.Module):
         self.layer_norm = LayerNorm(c_t)

+        self.tune_chunk_size = tune_chunk_size
+        self.chunk_size_tuner = None
+        if(tune_chunk_size):
+            self.chunk_size_tuner = ChunkSizeTuner()
+
     def forward(self,
         t: torch.tensor,
...
...
@@ -355,18 +362,28 @@ class TemplatePairStack(nn.Module):
         expand_idx[-3] = t.shape[-4]
         mask = mask.expand(*expand_idx)

+        blocks = [
+            partial(
+                b,
+                mask=mask,
+                chunk_size=chunk_size,
+                use_lma=use_lma,
+                _mask_trans=_mask_trans,
+                _inplace=not (self.training or torch.is_grad_enabled()),
+            )
+            for b in self.blocks
+        ]
+
+        if(chunk_size is not None and self.chunk_size_tuner is not None):
+            chunk_size = self.chunk_size_tuner.tune_chunk_size(
+                representative_fn=blocks[0],
+                args=(t,),
+                min_chunk_size=chunk_size,
+            )
+            blocks = [partial(b, chunk_size=chunk_size) for b in blocks]
+
         t, = checkpoint_blocks(
-            blocks=[
-                partial(
-                    b,
-                    mask=mask,
-                    chunk_size=chunk_size,
-                    use_lma=use_lma,
-                    _mask_trans=_mask_trans,
-                    _inplace=not (self.training or torch.is_grad_enabled()),
-                )
-                for b in self.blocks
-            ],
+            blocks=blocks,
             args=(t,),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )
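Note on the hunk above: TemplatePairStack.forward now binds the per-block keyword arguments once with functools.partial, lets the tuner probe the first bound block, and re-binds chunk_size before handing the blocks to checkpoint_blocks. The sketch below uses a plain loop in place of checkpoint_blocks (whose real job is activation checkpointing) and hypothetical blocks that take and return a single tensor; the double binding works because keyword arguments supplied by the outer partial override those stored by the inner one.

from functools import partial

def forward_sketch(stack_blocks, t, mask, chunk_size, use_lma=False, tuner=None):
    # First binding: arguments that are constant across blocks.
    blocks = [
        partial(b, mask=mask, chunk_size=chunk_size, use_lma=use_lma)
        for b in stack_blocks
    ]

    if chunk_size is not None and tuner is not None:
        chunk_size = tuner.tune_chunk_size(
            representative_fn=blocks[0],
            args=(t,),
            min_chunk_size=chunk_size,
        )
        # Second binding: the outer partial's chunk_size overrides the inner one.
        blocks = [partial(b, chunk_size=chunk_size) for b in blocks]

    # The real code calls checkpoint_blocks(blocks=blocks, args=(t,), blocks_per_ckpt=...).
    for b in blocks:
        t = b(t)
    return t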
...
...