Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6e66b218
Commit
6e66b218
authored
Jun 10, 2022
by
Gustaf Ahdritz
Browse files
Vastly lower peak inference memory usage
parent
ec5619fc
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
970 additions
and
218 deletions
+970
-218
openfold/config.py
openfold/config.py
+44
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+7
-1
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+3
-2
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+17
-12
openfold/model/evoformer.py
openfold/model/evoformer.py
+89
-51
openfold/model/model.py
openfold/model/model.py
+100
-67
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+22
-6
openfold/model/template.py
openfold/model/template.py
+287
-32
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+322
-35
openfold/np/protein.py
openfold/np/protein.py
+8
-0
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+8
-0
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+2
-1
openfold/utils/loss.py
openfold/utils/loss.py
+11
-5
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+11
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+1
-1
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+1
-1
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+36
-3
No files found.
openfold/config.py
View file @
6e66b218
...
...
@@ -10,6 +10,29 @@ def set_inf(c, inf):
c
[
k
]
=
inf
def
enforce_config_constraints
(
config
):
def
string_to_setting
(
s
):
path
=
s
.
split
(
'.'
)
setting
=
config
for
p
in
path
:
setting
=
setting
[
p
]
return
setting
mutually_exclusive_bools
=
[
(
"model.template.average_templates"
,
"model.template.offload_templates"
)
]
for
s1
,
s2
in
mutually_exclusive_bools
:
s1_setting
=
string_to_setting
(
s1
)
s2_setting
=
string_to_setting
(
s2
)
if
(
s1_setting
and
s2_setting
):
raise
ValueError
(
f
"Only one of
{
s1
}
and
{
s2
}
may be set at a time"
)
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
c
=
copy
.
deepcopy
(
config
)
if
name
==
"initial_training"
:
...
...
@@ -22,6 +45,14 @@ def model_config(name, train=False, low_prec=False):
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
elif
name
==
"finetuning_ptm"
:
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
name
==
"model_1"
:
# AF2 Suppl. Table 5, Model 1.1.1
c
.
data
.
train
.
max_extra_msa
=
5120
...
...
@@ -95,6 +126,8 @@ def model_config(name, train=False, low_prec=False):
# a global constant
set_inf
(
c
,
1e4
)
enforce_config_constraints
(
c
)
return
c
...
...
@@ -346,6 +379,16 @@ config = mlc.ConfigDict(
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
False
,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates"
:
False
,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates.
"offload_templates"
:
False
,
},
"extra_msa"
:
{
"extra_msa_embedder"
:
{
...
...
@@ -498,7 +541,7 @@ config = mlc.ConfigDict(
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.
0
,
"weight"
:
0.
,
"enabled"
:
tm_enabled
,
},
"eps"
:
eps
,
...
...
openfold/data/data_modules.py
View file @
6e66b218
...
...
@@ -625,14 +625,21 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
train_chain_data_cache_path
,
]
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
+
1
)
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
_roll_at_init
=
False
,
)
if
(
self
.
val_data_dir
is
not
None
):
self
.
eval_dataset
=
dataset_gen
(
data_dir
=
self
.
val_data_dir
,
...
...
@@ -660,7 +667,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset
=
None
if
(
stage
==
"train"
):
dataset
=
self
.
train_dataset
# Filter the dataset, if necessary
dataset
.
reroll
()
elif
(
stage
==
"eval"
):
...
...
openfold/data/data_pipeline.py
View file @
6e66b218
...
...
@@ -97,6 +97,7 @@ def unify_template_features(
chain_indices
=
np
.
array
(
n_templates
*
[
i
])
out_dict
[
"template_chain_index"
]
=
chain_indices
if
(
n_templates
!=
0
):
out_dicts
.
append
(
out_dict
)
out_dict
=
{
...
...
@@ -741,7 +742,7 @@ class DataPipeline:
)
->
FeatureDict
:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter
. No templates
.
hack from Twitter
(a.k.a. AlphaFold-Gap)
.
"""
with
open
(
fasta_path
,
'r'
)
as
f
:
fasta_str
=
f
.
read
()
...
...
openfold/data/data_transforms.py
View file @
6e66b218
...
...
@@ -728,6 +728,7 @@ def make_atom14_positions(protein):
for
index
,
correspondence
in
enumerate
(
correspondences
):
renaming_matrix
[
index
,
correspondence
]
=
1.0
all_matrices
[
resname
]
=
renaming_matrix
renaming_matrices
=
torch
.
stack
(
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
)
...
...
openfold/model/embedders.py
View file @
6e66b218
...
...
@@ -15,10 +15,10 @@
import
torch
import
torch.nn
as
nn
from
typing
import
Tuple
from
typing
import
Tuple
,
Optional
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
one_hot
from
openfold.utils.tensor_utils
import
add
,
one_hot
class
InputEmbedder
(
nn
.
Module
):
...
...
@@ -132,7 +132,6 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
"""
def
__init__
(
self
,
c_m
:
int
,
...
...
@@ -174,6 +173,7 @@ class RecyclingEmbedder(nn.Module):
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
_inplace
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
...
...
@@ -189,6 +189,19 @@ class RecyclingEmbedder(nn.Module):
z:
[*, N_res, N_res, C_z] pair embedding update
"""
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
if
(
_inplace
):
m
.
copy_
(
m_update
)
m_update
=
m
# [*, N, N, C_z]
z_update
=
self
.
layer_norm_z
(
z
)
if
(
_inplace
):
z
.
copy_
(
z_update
)
z_update
=
z
# This squared method might become problematic in FP16 mode.
bins
=
torch
.
linspace
(
self
.
min_bin
,
self
.
max_bin
,
...
...
@@ -197,13 +210,6 @@ class RecyclingEmbedder(nn.Module):
device
=
x
.
device
,
requires_grad
=
False
,
)
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins
=
bins
**
2
upper
=
torch
.
cat
(
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
...
...
@@ -217,7 +223,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d
=
self
.
linear
(
d
)
z_update
=
d
+
self
.
layer_norm_z
(
z
)
z_update
=
add
(
z_update
,
d
,
_inplace
)
return
m_update
,
z_update
...
...
@@ -315,7 +321,6 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
"""
def
__init__
(
self
,
c_in
:
int
,
...
...
openfold/model/evoformer.py
View file @
6e66b218
...
...
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.tensor_utils
import
chunk_layer
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
class
MSATransition
(
nn
.
Module
):
...
...
@@ -192,32 +192,76 @@ class EvoformerBlockCore(nn.Module):
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
m
=
m
+
self
.
msa_transition
(
# Need to dodge activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
z
=
add
(
z
,
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_inplace
=
inplace_safe
),
inplace
=
inplace_safe
,
)
tmu_update
=
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
,
_inplace
=
inplace_safe
,
_add_with_inplace
=
True
,
)
z
=
z
+
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
if
(
not
inplace_safe
):
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
del
tmu_update
tmu_update
=
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
,
_inplace
=
inplace_safe
,
_add_with_inplace
=
True
,
)
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
if
(
not
inplace_safe
):
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
del
tmu_update
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
self
.
tri_att_start
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
)
),
inplace
=
inplace_safe
,
)
z
=
z
+
self
.
ps_dropout_col_layer
(
z
=
add
(
z
,
self
.
ps_dropout_col_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
)
),
inplace
=
inplace_safe
,
)
z
=
z
+
self
.
pair_transition
(
z
=
add
(
z
,
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
return
m
,
z
...
...
@@ -378,17 +422,9 @@ class ExtraMSABlock(nn.Module):
_chunk_logits
:
Optional
[
int
]
=
1024
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
add
(
m1
,
m2
):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if
(
torch
.
is_grad_enabled
()):
m1
=
m1
+
m2
else
:
m1
+=
m2
return
m1
m
=
add
(
m
,
self
.
msa_dropout_layer
(
# If function calls could speak...
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
...
...
@@ -396,21 +432,24 @@ class ExtraMSABlock(nn.Module):
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_memory_efficient_kernel
=
not
_chunk_logits
and
not
use_lma
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
))
),
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
def
fn
(
m
,
z
):
m
=
add
(
m
,
m
=
add
(
m
,
self
.
msa_att_col
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
)
),
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
m
,
z
=
self
.
core
(
m
,
...
...
@@ -590,7 +629,6 @@ class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
...
...
openfold/model/model.py
View file @
6e66b218
...
...
@@ -12,18 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
weakref
import
torch
import
torch.nn
as
nn
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
atom14_to_atom37
,
)
from
openfold.model.embedders
import
(
InputEmbedder
,
RecyclingEmbedder
,
...
...
@@ -33,16 +27,26 @@ from openfold.model.embedders import (
)
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.heads
import
AuxiliaryHeads
import
openfold.np.residue_constants
as
residue_constants
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePointwiseAttention
,
embed_templates_average
,
embed_templates_offload
,
)
import
openfold.np.residue_constants
as
residue_constants
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
atom14_to_atom37
,
)
from
openfold.utils.loss
import
(
compute_plddt
,
)
from
openfold.utils.tensor_utils
import
(
add
,
dict_multimap
,
tensor_tree_map
,
)
...
...
@@ -64,52 +68,71 @@ class AlphaFold(nn.Module):
super
(
AlphaFold
,
self
).
__init__
()
self
.
globals
=
config
.
globals
config
=
config
.
model
template_config
=
config
.
template
extra_msa_config
=
config
.
extra_msa
self
.
config
=
config
.
model
self
.
template_config
=
self
.
config
.
template
self
.
extra_msa_config
=
self
.
config
.
extra_msa
# Main trunk + structure module
self
.
input_embedder
=
InputEmbedder
(
**
config
[
"input_embedder"
],
**
self
.
config
[
"input_embedder"
],
)
self
.
recycling_embedder
=
RecyclingEmbedder
(
**
config
[
"recycling_embedder"
],
**
self
.
config
[
"recycling_embedder"
],
)
self
.
template_angle_embedder
=
TemplateAngleEmbedder
(
**
template_config
[
"template_angle_embedder"
],
**
self
.
template_config
[
"template_angle_embedder"
],
)
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
**
template_config
[
"template_pair_embedder"
],
**
self
.
template_config
[
"template_pair_embedder"
],
)
self
.
template_pair_stack
=
TemplatePairStack
(
**
template_config
[
"template_pair_stack"
],
**
self
.
template_config
[
"template_pair_stack"
],
)
self
.
template_pointwise_att
=
TemplatePointwiseAttention
(
**
template_config
[
"template_pointwise_attention"
],
**
self
.
template_config
[
"template_pointwise_attention"
],
)
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
extra_msa_config
[
"extra_msa_embedder"
],
**
self
.
extra_msa_config
[
"extra_msa_embedder"
],
)
self
.
extra_msa_stack
=
ExtraMSAStack
(
**
extra_msa_config
[
"extra_msa_stack"
],
**
self
.
extra_msa_config
[
"extra_msa_stack"
],
)
self
.
evoformer
=
EvoformerStack
(
**
config
[
"evoformer_stack"
],
**
self
.
config
[
"evoformer_stack"
],
)
self
.
structure_module
=
StructureModule
(
**
config
[
"structure_module"
],
**
self
.
config
[
"structure_module"
],
)
self
.
aux_heads
=
AuxiliaryHeads
(
config
[
"heads"
],
self
.
config
[
"heads"
],
)
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
)
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
)
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
template_embeds
=
[]
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
if
(
inplace_safe
):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair
=
z
.
new_zeros
(
z
.
shape
[:
-
3
]
+
(
n_templ
,
n
,
n
,
self
.
globals
.
c_t
)
)
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
...
...
@@ -117,18 +140,7 @@ class AlphaFold(nn.Module):
batch
,
)
single_template_embeds
=
{}
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
single_template_feats
,
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
single_template_embeds
[
"angle"
]
=
a
# [*, S_t, N, N, C_t]
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
single_template_feats
,
use_unit_vector
=
self
.
config
.
template
.
use_unit_vector
,
...
...
@@ -138,23 +150,27 @@ class AlphaFold(nn.Module):
).
to
(
z
.
dtype
)
t
=
self
.
template_pair_embedder
(
t
)
single_template_embeds
.
update
({
"pair"
:
t
})
if
(
inplace_safe
):
t_pair
[...,
i
,
:,
:,
:]
=
t
else
:
pair_embeds
.
append
(
t
)
template_embeds
.
append
(
single_template_embeds
)
del
t
tem
pla
t
e_
embeds
=
dict_multimap
(
partial
(
torch
.
cat
,
dim
=
templ_dim
)
,
template_embeds
,
)
if
(
not
in
pla
c
e_
safe
):
t_pair
=
torch
.
cat
(
pair_embeds
,
dim
=
templ_dim
)
del
pair_embeds
# [*, S_t, N, N, C_z]
t
=
self
.
template_pair_stack
(
t
emplate_embeds
[
"
pair
"
]
,
t
_
pair
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
t_pair
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
...
...
@@ -164,17 +180,28 @@ class AlphaFold(nn.Module):
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
if
(
inplace_safe
):
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
self
.
config
.
template
.
embed_angles
:
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
template_angle_feat
=
build_template_angle_feat
(
batch
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_
prev
,
_recycle
=
True
):
def
iteration
(
self
,
feats
,
prev
s
,
_recycle
=
True
):
# Primary output dictionary
outputs
=
{}
...
...
@@ -190,13 +217,14 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
seq_mask
=
feats
[
"seq_mask"
]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
msa_mask
=
feats
[
"msa_mask"
]
# Initialize the MSA and pair representations
#
#
Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
...
...
@@ -206,6 +234,10 @@ class AlphaFold(nn.Module):
feats
[
"msa_feat"
],
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function.
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
...
...
@@ -236,24 +268,16 @@ class AlphaFold(nn.Module):
m_1_prev
,
z_prev
,
x_prev
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m]
m
[...,
0
,
:,
:]
+=
m_1_prev_emb
# [*, N, N, C_z]
z
+=
z_prev_emb
#
Possibly prevents memory fragmentation
#
This matters during inference with large N
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
...
...
@@ -269,7 +293,10 @@ class AlphaFold(nn.Module):
)
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
z
=
add
(
z
,
template_embeds
.
pop
(
"template_pair_embedding"
),
inplace_safe
,
)
if
self
.
config
.
template
.
embed_angles
:
# [*, S = S_c + S_t, N, C_m]
...
...
@@ -301,6 +328,8 @@ class AlphaFold(nn.Module):
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
a
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
...
...
@@ -416,6 +445,7 @@ class AlphaFold(nn.Module):
"""
# Initialize recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled
=
torch
.
is_grad_enabled
()
...
...
@@ -440,12 +470,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
m_1_prev
,
z_prev
,
x_prev
,
prevs
,
_recycle
=
(
num_iters
>
1
)
)
if
(
not
is_final_iter
):
del
outputs
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
del
m_1_prev
,
z_prev
,
x_prev
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/outer_product_mean.py
View file @
6e66b218
...
...
@@ -82,7 +82,13 @@ class OuterProductMean(nn.Module):
no_batch_dims
=
1
,
)
out
.
append
(
outer
)
# For some cursed reason making this distinction saves memory
if
(
len
(
out
)
==
1
):
outer
=
out
[
0
].
unsqueeze
(
0
)
else
:
outer
=
torch
.
stack
(
out
,
dim
=
0
)
outer
=
outer
.
reshape
(
a
.
shape
[:
-
3
]
+
outer
.
shape
[
1
:])
return
outer
...
...
@@ -90,7 +96,8 @@ class OuterProductMean(nn.Module):
def
forward
(
self
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
chunk_size
:
Optional
[
int
]
=
None
,
_inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -105,12 +112,17 @@ class OuterProductMean(nn.Module):
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
# [*, N_seq, N_res, C_m]
m
=
self
.
layer_norm
(
m
)
ln
=
self
.
layer_norm
(
m
)
# [*, N_seq, N_res, C]
mask
=
mask
.
unsqueeze
(
-
1
)
a
=
self
.
linear_1
(
m
)
*
mask
b
=
self
.
linear_2
(
m
)
*
mask
a
=
self
.
linear_1
(
ln
)
a
=
a
*
mask
b
=
self
.
linear_2
(
ln
)
b
=
b
*
mask
del
ln
a
=
a
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
...
...
@@ -122,8 +134,12 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
norm
=
torch
.
einsum
(
"...abc,...adc->...bdc"
,
mask
,
mask
)
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
outer
=
outer
/
(
self
.
eps
+
norm
)
if
(
_inplace
):
outer
/=
norm
else
:
outer
=
outer
/
norm
return
outer
openfold/model/template.py
View file @
6e66b218
...
...
@@ -34,10 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.feats
import
(
build_template_angle_feat
,
build_template_pair_feat
,
)
from
openfold.utils.tensor_utils
import
(
add
,
chunk_layer
,
permute_final_dims
,
flatten_final_dims
,
tensor_tree_map
,
)
...
...
@@ -191,7 +197,8 @@ class TemplatePairStackBlock(nn.Module):
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
):
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
...
...
@@ -203,42 +210,69 @@ class TemplatePairStackBlock(nn.Module):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single
=
single
+
self
.
dropout_row
(
single
=
add
(
single
,
self
.
dropout_row
(
self
.
tri_att_start
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
)
),
_inplace
,
)
single
=
single
+
self
.
dropout_col
(
single
=
add
(
single
,
self
.
dropout_col
(
self
.
tri_att_end
(
single
,
chunk_size
=
chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
)
),
_inplace
,
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_out
(
tmu_update
=
self
.
tri_mul_out
(
single
,
mask
=
single_mask
)
mask
=
single_mask
,
_inplace
=
_inplace
,
_add_with_inplace
=
True
,
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_in
(
if
(
not
_inplace
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
tmu_update
=
self
.
tri_mul_in
(
single
,
mask
=
single_mask
)
mask
=
single_mask
,
_inplace
=
_inplace
,
_add_with_inplace
=
True
,
)
single
=
single
+
self
.
pair_transition
(
if
(
not
_inplace
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
single
=
add
(
single
,
self
.
pair_transition
(
single
,
mask
=
single_mask
if
_mask_trans
else
None
,
chunk_size
=
chunk_size
,
),
_inplace
,
)
if
(
not
_inplace
):
single_templates
[
i
]
=
single
if
(
not
_inplace
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
...
...
@@ -328,6 +362,7 @@ class TemplatePairStack(nn.Module):
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
for
b
in
self
.
blocks
],
...
...
@@ -338,3 +373,223 @@ class TemplatePairStack(nn.Module):
t
=
self
.
layer_norm
(
t
)
return
t
def
embed_templates_offload
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
template_chunk_size
=
256
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
template_chunk_size:
Integer value controlling how quickly the offloaded pair embedding
tensor is brought back into GPU memory. In dire straits, can be
lowered to reduce memory consumption of this function even more.
Returns:
A dictionary of template pair and angle embeddings.
A version of the "embed_templates" method of the AlphaFold class that
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
lambda
t
:
torch
.
index_select
(
t
,
templ_dim
,
idx
),
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
single_template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
t
=
model
.
template_pair_embedder
(
t
)
# [*, 1, N, N, C_z]
t
=
model
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
pair_embeds_cpu
.
append
(
t
.
cpu
())
del
t
# Preallocate the output tensor
t
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n
,
template_chunk_size
):
pair_chunks
=
[
p
[...,
i
:
i
+
template_chunk_size
,
:,
:]
for
p
in
pair_embeds_cpu
]
pair_chunk
=
torch
.
cat
(
pair_chunks
,
dim
=
templ_dim
).
to
(
device
=
z
.
device
)
z_chunk
=
z
[...,
i
:
i
+
template_chunk_size
,
:,
:]
att_chunk
=
model
.
template_pointwise_att
(
pair_chunk
,
z_chunk
,
template_mask
=
batch
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
t
[...,
i
:
i
+
template_chunk_size
,
:,
:]
=
att_chunk
del
pair_chunks
if
(
inplace_safe
):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
def
embed_templates_average
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
templ_group_size
=
2
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
templ_group_size:
Granularity of the approximation. Larger values trade memory for
greater proximity to the original function
Returns:
A dictionary of template pair and angle embeddings.
A memory-efficient approximation of the "embed_templates" method of the
AlphaFold class. Instead of running pointwise attention over pair
embeddings for all of the templates at the same time, it splits templates
into groups of size templ_group_size, computes embeddings for each group
normally, and then averages the group embeddings. In our experiments, this
approximation has a minimal effect on the quality of the resulting
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
out_tensor
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n_templ
,
templ_group_size
):
def
slice_template_tensor
(
t
):
s
=
[
slice
(
None
)
for
_
in
t
.
shape
]
s
[
templ_dim
]
=
slice
(
i
,
i
+
templ_group_size
)
return
t
[
s
]
template_feats
=
tensor_tree_map
(
slice_template_tensor
,
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
# [*, S_t, N, N, C_z]
t
=
model
.
template_pair_embedder
(
t
)
t
=
model
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
t
=
model
.
template_pointwise_att
(
t
,
z
,
template_mask
=
template_feats
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
denom
=
math
.
ceil
(
n_templ
/
templ_group_size
)
if
(
inplace_safe
):
t
/=
denom
else
:
t
=
t
/
denom
if
(
inplace_safe
):
out_tensor
+=
t
else
:
out_tensor
=
out_tensor
+
t
del
t
if
(
inplace_safe
):
out_tensor
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
out_tensor
=
out_tensor
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
out_tensor
})
return
ret
openfold/model/triangular_multiplicative_update.py
View file @
6e66b218
...
...
@@ -20,7 +20,7 @@ import torch
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
permute_final_dims
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
,
permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
...
...
@@ -55,12 +55,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
        _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Combine the a and b projections with a batched matrix product,
        oriented according to self._outgoing (outgoing edges, Algorithm 11,
        vs. incoming edges, Algorithm 12).

        Args:
            a:
                [*, N_res, N_res, C] first projection of the pair rep
            b:
                [*, N_res, N_res, C] second projection of the pair rep
            _inplace_chunk_size:
                If given, the matmul is performed in chunks along the
                channel dimension and the result is written back into a's
                storage, capping peak memory during inference
        Returns:
            [*, N_res, N_res, C] combined tensor
        """
        if(self._outgoing):
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if(_inplace_chunk_size is not None):
            # To be replaced by torch vmap
            # After the permute, dim -3 is the channel dim; each channel
            # chunk's product is computed into a fresh tensor by matmul and
            # then assigned over the corresponding slice of a, so no extra
            # full-size intermediate is kept alive.
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
                a[..., i: i + _inplace_chunk_size, :, :] = (
                    torch.matmul(
                        a_chunk,
                        b_chunk,
                    )
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))
    def _inference_forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        More memory-efficient, inference-only version of the forward
        function. Uses in-place operations, fusion of the addition that
        happens after this module in the Evoformer, a smidge of
        recomputation, and a cache of overwritten values to lower peak
        memory consumption of this module from 5x the size of the input
        tensor z to 2.5x its size. Useful for inference on extremely long
        sequences.

        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to
                trade memory for speed.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it
                is overwritten with (update).
        Returns:
            A reference to the overwritten z

        It works as follows. We will make reference to variables used in
        the default forward implementation. Naively, triangle
        multiplication requires the manifestation of 5 tensors the size of
        z: 1) z, the "square" input tensor, 2) a, the first projection of
        z, 3) b, the second projection of z, 4) g, a z-sized mask, and 5) a
        z-sized tensor for intermediate computations. For large N this is
        prohibitively expensive; for N=4000, z alone is more than 8GB. To
        avoid this, we compute b, g, and all intermediate tensors in small
        chunks, noting that the chunks required to compute a chunk of the
        output depend only on the tensor a and corresponding vertical and
        horizontal chunks of z. This suggests an algorithm that loops over
        pairs of chunks of z: hereafter "columns" and "rows" of z, even
        though each in fact contains inplace_chunk_size contiguous true
        columns/rows. Writing output chunks to a new tensor would bring
        total memory to 3x the size of z; more is saved by writing output
        chunks directly into z in place. WLOG, output chunks are written
        vertically, overwriting the ith "column" of z at the end of the ith
        iteration of the main loop. Despite this overwriting, the ith
        column is always one column ahead of previously overwritten columns
        and can be recovered directly from z. After the first iteration,
        however, the ith row of z is always at least partially overwritten,
        so we introduce the z-cache, a tensor one-half the size of z. The
        z-cache initially contains the left half (2nd and 3rd quadrants) of
        z. For 0 < i < N/2, the missing left part of the ith row of z is
        recovered from this cache at the start of the ith iteration. Once i
        exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
        quadrants instead; though the 3rd quadrant of the original z is
        entirely overwritten by then, it is recoverable from the z-cache
        itself. Thereafter the ith row of z is recovered entirely from the
        reoriented z-cache. After the final iteration z has been completely
        overwritten and contains the triangular multiplicative update (or,
        if with_add, the sum of z and the update). Peak memory consumption
        is just 2.5x the size of z, disregarding chunks and other small
        variables.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N, N, 1] so it broadcasts over channels
        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask, a=True):
            # Gated, masked projection (the a- or b-branch of the update),
            # returned channel-first: [*, C, N, N]
            if(a):
                linear_g = self.linear_a_g
                linear_p = self.linear_a_p
            else:
                linear_g = self.linear_b_g
                linear_p = self.linear_b_p

            pair = self.layer_norm_in(pair)
            p = linear_g(pair)
            p.sigmoid_()
            p *= linear_p(pair)
            p *= mask
            p = permute_final_dims(p, (2, 0, 1))
            return p

        def compute_projection(pair, mask, a=True, chunked=True):
            # Whether this projection must be transposed depends on both
            # the branch (a/b) and the module orientation
            need_transpose = self._outgoing ^ a
            if(not chunked):
                p = compute_projection_helper(pair, mask, a)
                if(need_transpose):
                    p = p.transpose(-1, -2)
            else:
                # This computation is chunked so as not to exceed our 2.5x
                # budget with a large intermediate tensor
                linear_g = self.linear_a_g if a else self.linear_b_g
                c = linear_g.bias.shape[-1]
                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
                p = pair.new_zeros(out_shape)
                for i in range(0, pair.shape[-3], inplace_chunk_size):
                    # NOTE(review): these two slices are dead stores —
                    # they are recomputed immediately below
                    pair_chunk = pair[..., i: i + inplace_chunk_size, :, :]
                    mask_chunk = mask[..., i: i + inplace_chunk_size, :, :]
                    pair_chunk = compute_projection_helper(
                        pair[..., i: i + inplace_chunk_size, :, :],
                        mask[..., i: i + inplace_chunk_size, :, :],
                        a,
                    )
                    if(need_transpose):
                        pair_chunk = pair_chunk.transpose(-1, -2)
                        p[..., i: i + inplace_chunk_size] = pair_chunk
                    else:
                        p[..., i: i + inplace_chunk_size, :] = pair_chunk

                    del pair_chunk

            return p

        # We start by fully manifesting a. In addition to the input, this
        # brings total memory consumption to 2x z (disregarding size of
        # chunks)
        # [*, N, N, c]
        a = compute_projection(z, mask, True, chunked=True)

        if(inplace_chunk_size is not None):
            n = a.shape[-1]
            # Ceiling of n/2, so odd n is handled
            half_n = n // 2 + n % 2
            row_dim = -3
            col_dim = -2
            b_chunk_dim = row_dim if self._outgoing else col_dim

            def empty_slicer(t):
                return [slice(None) for _ in t.shape]

            def slice_tensor(t, start, end, dim):
                # Slices start:end from the dim dimension of t
                s = empty_slicer(t)
                s[dim] = slice(start, end)
                return t[s]

            def flip_z_cache_(z_cache, z):
                # "Reorient" the z_cache (see docstring), filling it with
                # quadrants 3---recovered from the z_cache---and
                # 4---recovered from z---of the input tensor z.
                quadrant_3 = slice_tensor(
                    z_cache, half_n, None, row_dim
                )
                z_cache = z_cache.transpose(row_dim, col_dim)

                # If n is odd, we need to shrink the z_cache by one row
                z_cache = z_cache[..., :(n // 2), :, :]

                # Move the 3rd quadrant of z into the first half
                first_half_slicer = empty_slicer(z_cache)
                first_half_slicer[col_dim] = slice(0, half_n)
                z_cache[first_half_slicer] = quadrant_3

                # Get the fourth quadrant of z
                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
                quadrant_4 = slice_tensor(
                    quadrant_4, half_n, None, col_dim
                )

                # Insert said quadrant into the rotated z-cache
                quadrant_3_slicer = empty_slicer(z_cache)
                quadrant_3_slicer[col_dim] = slice(half_n, None)

                z_cache[quadrant_3_slicer] = quadrant_4

                return z_cache

            # Initialize the z cache to the left half of z.
            z_cache_shape = list(z.shape)
            z_cache_shape[col_dim] = half_n
            z_cache = z.new_zeros(z_cache_shape)
            z_cache_slicer = empty_slicer(z_cache)
            z_cache_slicer[col_dim] = slice(0, half_n)
            z_cache.copy_(z[z_cache_slicer])
            z_cache_rotated = False

            # We need to reorient the z-cache at the halfway point, and we
            # don't want a single chunk to straddle that point. We contract
            # one of the chunks in the middle to address that problem.
            i_range = list(range(0, half_n, inplace_chunk_size))
            initial_offsets = [
                i_2 - i_1
                for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])
            ]
            after_half = list(range(half_n, n, inplace_chunk_size))
            after_half_offsets = [inplace_chunk_size for _ in after_half]
            combined_range_with_offsets = zip(
                i_range + after_half, initial_offsets + after_half_offsets
            )
            for i, offset in combined_range_with_offsets:
                if(not z_cache_rotated and i >= half_n):
                    z_cache = flip_z_cache_(z_cache, z)
                    z_cache_rotated = True

                z_chunk_b = slice_tensor(
                    z, i, i + offset, b_chunk_dim,
                )
                mask_chunk = slice_tensor(
                    mask, i, i + offset, b_chunk_dim,
                )

                # NOTE(review): this clone is discarded in the col_dim
                # branch just below, which re-slices z instead
                z_chunk_b = z_chunk_b.clone()
                if(b_chunk_dim == col_dim):
                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
                else:
                    # b_chunk_dim == row_dim
                    # In this case, the b-dimension (b_chunk_dim) is
                    # partially overwritten at the end of each iteration. We
                    # need to restore the missing component from the
                    # z-cache.
                    if(not z_cache_rotated):
                        z_chunk_slicer = empty_slicer(z_chunk_b)
                        z_chunk_slicer[col_dim] = slice(0, half_n)
                        z_chunk_b[z_chunk_slicer] = slice_tensor(
                            z_cache, i, i + offset, row_dim,
                        )
                    else:
                        z_cache_offset = i - half_n
                        z_chunk_b = slice_tensor(
                            z_cache,
                            z_cache_offset,
                            z_cache_offset + offset,
                            row_dim
                        )

                b_chunk = compute_projection(
                    z_chunk_b, mask_chunk, a=False, chunked=False
                )
                del z_chunk_b

                x_chunk = torch.matmul(
                    a,
                    b_chunk,
                )
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)

                # The g dimension (col_dim) is parallel to and ahead of the
                # overwrites in z. We can extract the g chunk normally.
                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
                g_chunk.sigmoid_()
                del z_chunk_g

                x_chunk *= g_chunk

                # Write the columns into z in-place
                z_slicer = empty_slicer(z)
                z_slicer[col_dim] = slice(i, i + offset)
                if(with_add):
                    z[z_slicer] += x_chunk
                else:
                    z[z_slicer] = x_chunk
        else:
            # Un-chunked fallback: still in-place/fused, but manifests full
            # intermediate tensors
            b = compute_projection(z, mask, False, False)
            x = torch.matmul(a, b)
            x = self.layer_norm_out(x)
            x = self.linear_z(x)
            g = self.linear_g(z)
            g.sigmoid_()
            x *= g
            if(with_add):
                z += x
            else:
                z = x

        return z
    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        _inplace: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] pair mask (defaults to all-ones)
            _inplace:
                If True, dispatch to the memory-efficient, in-place
                inference path (_inference_forward), which overwrites z.
                NOTE(review): the tests in this commit pass the kwarg
                `_inference_mode` — confirm the parameter name matches.
            _add_with_inplace:
                Inference path only: if True, z is overwritten with
                (z + update) rather than (update)
            _inplace_chunk_size:
                Inference path only: chunk size for the in-place loop
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if(_inplace):
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N, N, 1] so it broadcasts over channels
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        # Build the projections starting from the mask so that the largest
        # intermediates are accumulated in-place where possible
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)
        x = self._combine_projections(a, b)
        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 11.
    """
    # All computation lives in TriangleMultiplicativeUpdate; this subclass
    # merely fixes the orientation flag at construction time.
    __init__ = partialmethod(
        TriangleMultiplicativeUpdate.__init__, _outgoing=True
    )
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """
    Implements Algorithm 12.
    """
    # All computation lives in TriangleMultiplicativeUpdate; this subclass
    # merely fixes the orientation flag at construction time.
    __init__ = partialmethod(
        TriangleMultiplicativeUpdate.__init__, _outgoing=False
    )
openfold/np/protein.py
View file @
6e66b218
...
...
@@ -140,12 +140,20 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
residue_index
.
append
(
res
.
id
[
1
])
b_factors
.
append
(
res_b_factors
)
parents
=
None
if
(
"PARENT"
in
pdb_str
):
for
l
in
pdb_str
.
split
(
"
\n
"
):
if
(
"PARENT"
in
l
and
not
"N/A"
in
l
):
parents
=
l
.
split
()[
1
:]
break
return
Protein
(
atom_positions
=
np
.
array
(
atom_positions
),
atom_mask
=
np
.
array
(
atom_mask
),
aatype
=
np
.
array
(
aatype
),
residue_index
=
np
.
array
(
residue_index
),
b_factors
=
np
.
array
(
b_factors
),
parents
=
parents
,
)
...
...
openfold/np/relax/amber_minimize.py
View file @
6e66b218
...
...
@@ -516,6 +516,9 @@ def run_pipeline(
_check_residues_are_well_defined
(
prot
)
pdb_string
=
clean_protein
(
prot
,
checks
=
checks
)
# We keep the input around to restore metadata deleted by the relaxer
input_prot
=
prot
exclude_residues
=
exclude_residues
or
[]
exclude_residues
=
set
(
exclude_residues
)
violations
=
np
.
inf
...
...
@@ -532,6 +535,11 @@ def run_pipeline(
max_attempts
=
max_attempts
,
use_gpu
=
use_gpu
,
)
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
ret
[
"min_pdb"
]
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
ret
[
"min_pdb"
]])
prot
=
protein
.
from_pdb_string
(
ret
[
"min_pdb"
])
if
place_hydrogens_every_iteration
:
pdb_string
=
clean_protein
(
prot
,
checks
=
True
)
...
...
openfold/utils/exponential_moving_average.py
View file @
6e66b218
...
...
@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
self
.
params
=
state_dict
[
"params"
]
for
k
in
state_dict
[
"params"
].
keys
():
self
.
params
[
k
]
=
state_dict
[
"params"
][
k
].
clone
()
self
.
decay
=
state_dict
[
"decay"
]
def
state_dict
(
self
)
->
OrderedDict
:
...
...
openfold/utils/loss.py
View file @
6e66b218
...
...
@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
def sigmoid_cross_entropy(logits, labels):
    """
    Elementwise binary cross-entropy computed from raw logits.

    The arithmetic is carried out in float64 and cast back to the input
    dtype; logsigmoid is used in place of log(sigmoid(...)) for numerical
    stability at large |logits|.
    """
    input_dtype = logits.dtype
    logits64 = logits.double()
    labels64 = labels.double()
    # log sigmoid(x) and log sigmoid(-x) == log(1 - sigmoid(x))
    positive_term = labels64 * torch.nn.functional.logsigmoid(logits64)
    negative_term = (1. - labels64) * torch.nn.functional.logsigmoid(-1 * logits64)
    loss = -(positive_term + negative_term)
    return loss.to(dtype=input_dtype)
...
...
openfold/utils/tensor_utils.py
View file @
6e66b218
...
...
@@ -19,6 +19,17 @@ import torch.nn as nn
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
def add(m1, m2, inplace):
    """
    Add m2 to m1, either out-of-place (returning a fresh tensor) or
    in-place (mutating and returning m1).

    The first operation in a checkpointed function can't be in-place, but
    in-place addition is desirable during inference — hence the flag.
    """
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
def
permute_final_dims
(
tensor
:
torch
.
Tensor
,
inds
:
List
[
int
]):
zero_index
=
-
1
*
len
(
inds
)
first_inds
=
list
(
range
(
len
(
tensor
.
shape
[:
zero_index
])))
...
...
run_pretrained_openfold.py
View file @
6e66b218
...
...
@@ -110,7 +110,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
# Prep protein metadata
template_domain_names
=
[]
template_chain_index
=
None
if
(
feature_processor
.
config
.
common
.
use_templates
):
if
(
feature_processor
.
config
.
common
.
use_templates
and
"template_domain_names"
in
feature_dict
):
template_domain_names
=
[
t
.
decode
(
"utf-8"
)
for
t
in
feature_dict
[
"template_domain_names"
]
]
...
...
scripts/precompute_alignments.py
View file @
6e66b218
...
...
@@ -227,7 +227,7 @@ if __name__ == "__main__":
)
add_data_args
(
parser
)
parser
.
add_argument
(
"--raise_errors"
,
type
=
bool
,
default
=
False
,
"--raise_errors"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to crash on parsing errors"
)
parser
.
add_argument
(
...
...
tests/test_triangular_multiplicative_update.py
View file @
6e66b218
...
...
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_shape
(
self
):
c_z
=
consts
.
c_z
c
=
11
outgoing
=
True
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
outgoing
,
)
n_res
=
consts
.
c_z
...
...
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
_inference_mode
=
True
,
_inplace_chunk_size
=
4
,
).
cpu
()
self
.
assertTrue
(
torch
.
m
ax
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
m
ean
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_mul_out_compare
(
self
):
...
...
@@ -106,6 +105,40 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_tri_mul_in_compare
(
self
):
self
.
_tri_mul_compare
(
incoming
=
True
)
    def _tri_mul_inference_mode(self, incoming=False):
        """
        Checks that the in-place inference path of a pretrained model's
        triangle-multiplication module matches the stock forward pass.

        Requires the global pretrained model and a CUDA device.
        """
        n_res = consts.n_res

        # Random pair activations and a random binary pair mask
        pair_act = np.random.rand(
            n_res, n_res, consts.c_z
        ).astype(np.float32)
        pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
        pair_mask = pair_mask.astype(np.float32)

        model = compare_utils.get_global_pretrained_openfold()
        module = (
            model.evoformer.blocks[0].core.tri_mul_in
            if incoming else
            model.evoformer.blocks[0].core.tri_mul_out
        )

        # NOTE(review): forward() as shown in this commit takes `_inplace`,
        # not `_inference_mode` — confirm the kwarg name matches the module.
        out_stock = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            _inference_mode=False,
        ).cpu()

        # This has to come second because inference mode is in-place
        out_inference_mode = module(
            torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
            mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
            _inference_mode=True,
            _inplace_chunk_size=2,
        ).cpu()

        # Mean absolute deviation must fall under the shared test epsilon
        self.assertTrue(
            torch.mean(torch.abs(out_stock - out_inference_mode))
            < consts.eps
        )
    def test_tri_mul_out_inference(self):
        # Outgoing-edge variant (Algorithm 11)
        self._tri_mul_inference_mode()

    def test_tri_mul_in_inference(self):
        # Incoming-edge variant (Algorithm 12)
        self._tri_mul_inference_mode(incoming=True)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment