Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6e66b218
"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "635f1e94e855d7832363ecdb2ed70affe487608a"
Commit
6e66b218
authored
Jun 10, 2022
by
Gustaf Ahdritz
Browse files
Vastly lower peak inference memory usage
parent
ec5619fc
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
970 additions
and
218 deletions
+970
-218
openfold/config.py
openfold/config.py
+44
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+7
-1
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+3
-2
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+1
-0
openfold/model/embedders.py
openfold/model/embedders.py
+17
-12
openfold/model/evoformer.py
openfold/model/evoformer.py
+89
-51
openfold/model/model.py
openfold/model/model.py
+100
-67
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+22
-6
openfold/model/template.py
openfold/model/template.py
+287
-32
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+322
-35
openfold/np/protein.py
openfold/np/protein.py
+8
-0
openfold/np/relax/amber_minimize.py
openfold/np/relax/amber_minimize.py
+8
-0
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+2
-1
openfold/utils/loss.py
openfold/utils/loss.py
+11
-5
openfold/utils/tensor_utils.py
openfold/utils/tensor_utils.py
+11
-0
run_pretrained_openfold.py
run_pretrained_openfold.py
+1
-1
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+1
-1
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+36
-3
No files found.
openfold/config.py
View file @
6e66b218
...
@@ -10,6 +10,29 @@ def set_inf(c, inf):
...
@@ -10,6 +10,29 @@ def set_inf(c, inf):
c
[
k
]
=
inf
c
[
k
]
=
inf
def
enforce_config_constraints
(
config
):
def
string_to_setting
(
s
):
path
=
s
.
split
(
'.'
)
setting
=
config
for
p
in
path
:
setting
=
setting
[
p
]
return
setting
mutually_exclusive_bools
=
[
(
"model.template.average_templates"
,
"model.template.offload_templates"
)
]
for
s1
,
s2
in
mutually_exclusive_bools
:
s1_setting
=
string_to_setting
(
s1
)
s2_setting
=
string_to_setting
(
s2
)
if
(
s1_setting
and
s2_setting
):
raise
ValueError
(
f
"Only one of
{
s1
}
and
{
s2
}
may be set at a time"
)
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
def
model_config
(
name
,
train
=
False
,
low_prec
=
False
):
c
=
copy
.
deepcopy
(
config
)
c
=
copy
.
deepcopy
(
config
)
if
name
==
"initial_training"
:
if
name
==
"initial_training"
:
...
@@ -22,6 +45,14 @@ def model_config(name, train=False, low_prec=False):
...
@@ -22,6 +45,14 @@ def model_config(name, train=False, low_prec=False):
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
elif
name
==
"finetuning_ptm"
:
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
max_msa_clusters
=
512
c
.
loss
.
violation
.
weight
=
1.
c
.
loss
.
experimentally_resolved
.
weight
=
0.01
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
name
==
"model_1"
:
elif
name
==
"model_1"
:
# AF2 Suppl. Table 5, Model 1.1.1
# AF2 Suppl. Table 5, Model 1.1.1
c
.
data
.
train
.
max_extra_msa
=
5120
c
.
data
.
train
.
max_extra_msa
=
5120
...
@@ -95,6 +126,8 @@ def model_config(name, train=False, low_prec=False):
...
@@ -95,6 +126,8 @@ def model_config(name, train=False, low_prec=False):
# a global constant
# a global constant
set_inf
(
c
,
1e4
)
set_inf
(
c
,
1e4
)
enforce_config_constraints
(
c
)
return
c
return
c
...
@@ -346,6 +379,16 @@ config = mlc.ConfigDict(
...
@@ -346,6 +379,16 @@ config = mlc.ConfigDict(
"enabled"
:
templates_enabled
,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
False
,
"use_unit_vector"
:
False
,
# Approximate template computation, saving memory.
# In our experiments, results are equivalent to or better than
# the stock implementation. Should be enabled for all new
# training runs.
"average_templates"
:
False
,
# Offload template embeddings to CPU memory. Vastly reduced
# memory consumption at the cost of a modest increase in
# runtime. Useful for inference on very long sequences.
# Mutually exclusive with average_templates.
"offload_templates"
:
False
,
},
},
"extra_msa"
:
{
"extra_msa"
:
{
"extra_msa_embedder"
:
{
"extra_msa_embedder"
:
{
...
@@ -498,7 +541,7 @@ config = mlc.ConfigDict(
...
@@ -498,7 +541,7 @@ config = mlc.ConfigDict(
"min_resolution"
:
0.1
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.
0
,
"weight"
:
0.
,
"enabled"
:
tm_enabled
,
"enabled"
:
tm_enabled
,
},
},
"eps"
:
eps
,
"eps"
:
eps
,
...
...
openfold/data/data_modules.py
View file @
6e66b218
...
@@ -625,13 +625,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -625,13 +625,20 @@ class OpenFoldDataModule(pl.LightningDataModule):
self
.
train_chain_data_cache_path
,
self
.
train_chain_data_cache_path
,
]
]
generator
=
None
if
(
self
.
batch_seed
is
not
None
):
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
self
.
batch_seed
+
1
)
self
.
train_dataset
=
OpenFoldDataset
(
self
.
train_dataset
=
OpenFoldDataset
(
datasets
=
datasets
,
datasets
=
datasets
,
probabilities
=
probabilities
,
probabilities
=
probabilities
,
epoch_len
=
self
.
train_epoch_len
,
epoch_len
=
self
.
train_epoch_len
,
chain_data_cache_paths
=
chain_data_cache_paths
,
chain_data_cache_paths
=
chain_data_cache_paths
,
generator
=
generator
,
_roll_at_init
=
False
,
_roll_at_init
=
False
,
)
)
if
(
self
.
val_data_dir
is
not
None
):
if
(
self
.
val_data_dir
is
not
None
):
self
.
eval_dataset
=
dataset_gen
(
self
.
eval_dataset
=
dataset_gen
(
...
@@ -660,7 +667,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
...
@@ -660,7 +667,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
dataset
=
None
dataset
=
None
if
(
stage
==
"train"
):
if
(
stage
==
"train"
):
dataset
=
self
.
train_dataset
dataset
=
self
.
train_dataset
# Filter the dataset, if necessary
# Filter the dataset, if necessary
dataset
.
reroll
()
dataset
.
reroll
()
elif
(
stage
==
"eval"
):
elif
(
stage
==
"eval"
):
...
...
openfold/data/data_pipeline.py
View file @
6e66b218
...
@@ -97,7 +97,8 @@ def unify_template_features(
...
@@ -97,7 +97,8 @@ def unify_template_features(
chain_indices
=
np
.
array
(
n_templates
*
[
i
])
chain_indices
=
np
.
array
(
n_templates
*
[
i
])
out_dict
[
"template_chain_index"
]
=
chain_indices
out_dict
[
"template_chain_index"
]
=
chain_indices
out_dicts
.
append
(
out_dict
)
if
(
n_templates
!=
0
):
out_dicts
.
append
(
out_dict
)
out_dict
=
{
out_dict
=
{
k
:
np
.
concatenate
([
od
[
k
]
for
od
in
out_dicts
])
for
k
in
out_dicts
[
0
]
k
:
np
.
concatenate
([
od
[
k
]
for
od
in
out_dicts
])
for
k
in
out_dicts
[
0
]
...
@@ -741,7 +742,7 @@ class DataPipeline:
...
@@ -741,7 +742,7 @@ class DataPipeline:
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter
. No templates
.
hack from Twitter
(a.k.a. AlphaFold-Gap)
.
"""
"""
with
open
(
fasta_path
,
'r'
)
as
f
:
with
open
(
fasta_path
,
'r'
)
as
f
:
fasta_str
=
f
.
read
()
fasta_str
=
f
.
read
()
...
...
openfold/data/data_transforms.py
View file @
6e66b218
...
@@ -728,6 +728,7 @@ def make_atom14_positions(protein):
...
@@ -728,6 +728,7 @@ def make_atom14_positions(protein):
for
index
,
correspondence
in
enumerate
(
correspondences
):
for
index
,
correspondence
in
enumerate
(
correspondences
):
renaming_matrix
[
index
,
correspondence
]
=
1.0
renaming_matrix
[
index
,
correspondence
]
=
1.0
all_matrices
[
resname
]
=
renaming_matrix
all_matrices
[
resname
]
=
renaming_matrix
renaming_matrices
=
torch
.
stack
(
renaming_matrices
=
torch
.
stack
(
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
)
)
...
...
openfold/model/embedders.py
View file @
6e66b218
...
@@ -15,10 +15,10 @@
...
@@ -15,10 +15,10 @@
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Tuple
from
typing
import
Tuple
,
Optional
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
one_hot
from
openfold.utils.tensor_utils
import
add
,
one_hot
class
InputEmbedder
(
nn
.
Module
):
class
InputEmbedder
(
nn
.
Module
):
...
@@ -132,7 +132,6 @@ class RecyclingEmbedder(nn.Module):
...
@@ -132,7 +132,6 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
Implements Algorithm 32.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
c_m
:
int
,
c_m
:
int
,
...
@@ -174,6 +173,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -174,6 +173,7 @@ class RecyclingEmbedder(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
_inplace
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
...
@@ -189,6 +189,19 @@ class RecyclingEmbedder(nn.Module):
...
@@ -189,6 +189,19 @@ class RecyclingEmbedder(nn.Module):
z:
z:
[*, N_res, N_res, C_z] pair embedding update
[*, N_res, N_res, C_z] pair embedding update
"""
"""
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
if
(
_inplace
):
m
.
copy_
(
m_update
)
m_update
=
m
# [*, N, N, C_z]
z_update
=
self
.
layer_norm_z
(
z
)
if
(
_inplace
):
z
.
copy_
(
z_update
)
z_update
=
z
# This squared method might become problematic in FP16 mode.
bins
=
torch
.
linspace
(
bins
=
torch
.
linspace
(
self
.
min_bin
,
self
.
min_bin
,
self
.
max_bin
,
self
.
max_bin
,
...
@@ -197,13 +210,6 @@ class RecyclingEmbedder(nn.Module):
...
@@ -197,13 +210,6 @@ class RecyclingEmbedder(nn.Module):
device
=
x
.
device
,
device
=
x
.
device
,
requires_grad
=
False
,
requires_grad
=
False
,
)
)
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins
=
bins
**
2
squared_bins
=
bins
**
2
upper
=
torch
.
cat
(
upper
=
torch
.
cat
(
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
[
squared_bins
[
1
:],
squared_bins
.
new_tensor
([
self
.
inf
])],
dim
=-
1
...
@@ -217,7 +223,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -217,7 +223,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
d
=
self
.
linear
(
d
)
d
=
self
.
linear
(
d
)
z_update
=
d
+
self
.
layer_norm_z
(
z
)
z_update
=
add
(
z_update
,
d
,
_inplace
)
return
m_update
,
z_update
return
m_update
,
z_update
...
@@ -315,7 +321,6 @@ class ExtraMSAEmbedder(nn.Module):
...
@@ -315,7 +321,6 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
Implements Algorithm 2, line 15
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
c_in
:
int
,
c_in
:
int
,
...
...
openfold/model/evoformer.py
View file @
6e66b218
...
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -37,7 +37,7 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.tensor_utils
import
chunk_layer
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
class
MSATransition
(
nn
.
Module
):
class
MSATransition
(
nn
.
Module
):
...
@@ -192,32 +192,76 @@ class EvoformerBlockCore(nn.Module):
...
@@ -192,32 +192,76 @@ class EvoformerBlockCore(nn.Module):
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
m
=
m
+
self
.
msa_transition
(
# Need to dodge activation checkpoints
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
m
=
add
(
m
,
self
.
msa_transition
(
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
)
z
=
z
+
self
.
outer_product_mean
(
z
=
add
(
z
,
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_inplace
=
inplace_safe
),
inplace
=
inplace_safe
,
)
)
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
))
z
=
z
+
self
.
ps_dropout_row_layer
(
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
))
tmu_update
=
self
.
tri_mul_out
(
z
=
z
+
self
.
ps_dropout_row_layer
(
z
,
self
.
tri_att_start
(
mask
=
pair_mask
,
z
,
_inplace
=
inplace_safe
,
mask
=
pair_mask
,
_add_with_inplace
=
True
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
)
)
)
z
=
z
+
self
.
ps_dropout_col_layer
(
if
(
not
inplace_safe
):
self
.
tri_att_end
(
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
z
,
else
:
mask
=
pair_mask
,
z
=
tmu_update
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
del
tmu_update
)
tmu_update
=
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
,
_inplace
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
z
=
z
+
self
.
ps_dropout_row_layer
(
tmu_update
)
else
:
z
=
tmu_update
del
tmu_update
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
self
.
tri_att_start
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
)
),
inplace
=
inplace_safe
,
)
z
=
add
(
z
,
self
.
ps_dropout_col_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
)
),
inplace
=
inplace_safe
,
)
)
z
=
z
+
self
.
pair_transition
(
z
=
add
(
z
,
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
,
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
)
return
m
,
z
return
m
,
z
...
@@ -377,40 +421,35 @@ class ExtraMSABlock(nn.Module):
...
@@ -377,40 +421,35 @@ class ExtraMSABlock(nn.Module):
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
add
(
m1
,
m2
):
# If function calls could speak...
# The first operation in a checkpoint can't be in-place, but it's
m
=
add
(
m
,
# nice to have in-place addition during inference. Thus...
self
.
msa_dropout_layer
(
if
(
torch
.
is_grad_enabled
()):
self
.
msa_att_row
(
m1
=
m1
+
m2
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
else
:
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
m1
+=
m2
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
return
m1
use_lma
=
use_lma
,
use_memory_efficient_kernel
=
not
_chunk_logits
and
not
use_lma
,
m
=
add
(
m
,
self
.
msa_dropout_layer
(
_chunk_logits
=
self
.
msa_att_row
(
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
m
.
clone
()
if
torch
.
is_grad_enabled
()
else
m
,
_checkpoint_chunks
=
z
=
z
.
clone
()
if
torch
.
is_grad_enabled
()
else
z
,
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
mask
=
msa_mask
,
)
chunk_size
=
chunk_size
,
),
use_lma
=
use_lma
,
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
use_memory_efficient_kernel
=
not
_chunk_logits
and
not
use_lma
,
)
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
))
def
fn
(
m
,
z
):
def
fn
(
m
,
z
):
m
=
add
(
m
=
add
(
m
,
m
,
self
.
msa_att_col
(
self
.
msa_att_col
(
m
,
m
,
mask
=
msa_mask
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
)
),
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
)
m
,
z
=
self
.
core
(
m
,
z
=
self
.
core
(
m
,
m
,
...
@@ -590,7 +629,6 @@ class ExtraMSAStack(nn.Module):
...
@@ -590,7 +629,6 @@ class ExtraMSAStack(nn.Module):
"""
"""
Implements Algorithm 18.
Implements Algorithm 18.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_m
:
int
,
c_z
:
int
,
c_z
:
int
,
...
...
openfold/model/model.py
View file @
6e66b218
...
@@ -12,18 +12,12 @@
...
@@ -12,18 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
from
functools
import
partial
import
weakref
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
atom14_to_atom37
,
)
from
openfold.model.embedders
import
(
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedder
,
RecyclingEmbedder
,
RecyclingEmbedder
,
...
@@ -33,16 +27,26 @@ from openfold.model.embedders import (
...
@@ -33,16 +27,26 @@ from openfold.model.embedders import (
)
)
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.heads
import
AuxiliaryHeads
from
openfold.model.heads
import
AuxiliaryHeads
import
openfold.np.residue_constants
as
residue_constants
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.template
import
(
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePairStack
,
TemplatePointwiseAttention
,
TemplatePointwiseAttention
,
embed_templates_average
,
embed_templates_offload
,
)
import
openfold.np.residue_constants
as
residue_constants
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
build_extra_msa_feat
,
build_template_angle_feat
,
build_template_pair_feat
,
atom14_to_atom37
,
)
)
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
compute_plddt
,
compute_plddt
,
)
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
add
,
dict_multimap
,
dict_multimap
,
tensor_tree_map
,
tensor_tree_map
,
)
)
...
@@ -64,52 +68,71 @@ class AlphaFold(nn.Module):
...
@@ -64,52 +68,71 @@ class AlphaFold(nn.Module):
super
(
AlphaFold
,
self
).
__init__
()
super
(
AlphaFold
,
self
).
__init__
()
self
.
globals
=
config
.
globals
self
.
globals
=
config
.
globals
config
=
config
.
model
self
.
config
=
config
.
model
template_config
=
config
.
template
self
.
template_config
=
self
.
config
.
template
extra_msa_config
=
config
.
extra_msa
self
.
extra_msa_config
=
self
.
config
.
extra_msa
# Main trunk + structure module
# Main trunk + structure module
self
.
input_embedder
=
InputEmbedder
(
self
.
input_embedder
=
InputEmbedder
(
**
config
[
"input_embedder"
],
**
self
.
config
[
"input_embedder"
],
)
)
self
.
recycling_embedder
=
RecyclingEmbedder
(
self
.
recycling_embedder
=
RecyclingEmbedder
(
**
config
[
"recycling_embedder"
],
**
self
.
config
[
"recycling_embedder"
],
)
)
self
.
template_angle_embedder
=
TemplateAngleEmbedder
(
self
.
template_angle_embedder
=
TemplateAngleEmbedder
(
**
template_config
[
"template_angle_embedder"
],
**
self
.
template_config
[
"template_angle_embedder"
],
)
)
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
**
template_config
[
"template_pair_embedder"
],
**
self
.
template_config
[
"template_pair_embedder"
],
)
)
self
.
template_pair_stack
=
TemplatePairStack
(
self
.
template_pair_stack
=
TemplatePairStack
(
**
template_config
[
"template_pair_stack"
],
**
self
.
template_config
[
"template_pair_stack"
],
)
)
self
.
template_pointwise_att
=
TemplatePointwiseAttention
(
self
.
template_pointwise_att
=
TemplatePointwiseAttention
(
**
template_config
[
"template_pointwise_attention"
],
**
self
.
template_config
[
"template_pointwise_attention"
],
)
)
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
extra_msa_config
[
"extra_msa_embedder"
],
**
self
.
extra_msa_config
[
"extra_msa_embedder"
],
)
)
self
.
extra_msa_stack
=
ExtraMSAStack
(
self
.
extra_msa_stack
=
ExtraMSAStack
(
**
extra_msa_config
[
"extra_msa_stack"
],
**
self
.
extra_msa_config
[
"extra_msa_stack"
],
)
)
self
.
evoformer
=
EvoformerStack
(
self
.
evoformer
=
EvoformerStack
(
**
config
[
"evoformer_stack"
],
**
self
.
config
[
"evoformer_stack"
],
)
)
self
.
structure_module
=
StructureModule
(
self
.
structure_module
=
StructureModule
(
**
config
[
"structure_module"
],
**
self
.
config
[
"structure_module"
],
)
)
self
.
aux_heads
=
AuxiliaryHeads
(
self
.
aux_heads
=
AuxiliaryHeads
(
config
[
"heads"
],
self
.
config
[
"heads"
],
)
)
self
.
config
=
config
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
)
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
)
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
# Embed the templates one at a time (with a poor man's vmap)
template_embeds
=
[]
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
if
(
inplace_safe
):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair
=
z
.
new_zeros
(
z
.
shape
[:
-
3
]
+
(
n_templ
,
n
,
n
,
self
.
globals
.
c_t
)
)
for
i
in
range
(
n_templ
):
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
single_template_feats
=
tensor_tree_map
(
...
@@ -117,18 +140,7 @@ class AlphaFold(nn.Module):
...
@@ -117,18 +140,7 @@ class AlphaFold(nn.Module):
batch
,
batch
,
)
)
single_template_embeds
=
{}
# [*, N, N, C_t]
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
single_template_feats
,
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
single_template_embeds
[
"angle"
]
=
a
# [*, S_t, N, N, C_t]
t
=
build_template_pair_feat
(
t
=
build_template_pair_feat
(
single_template_feats
,
single_template_feats
,
use_unit_vector
=
self
.
config
.
template
.
use_unit_vector
,
use_unit_vector
=
self
.
config
.
template
.
use_unit_vector
,
...
@@ -138,23 +150,27 @@ class AlphaFold(nn.Module):
...
@@ -138,23 +150,27 @@ class AlphaFold(nn.Module):
).
to
(
z
.
dtype
)
).
to
(
z
.
dtype
)
t
=
self
.
template_pair_embedder
(
t
)
t
=
self
.
template_pair_embedder
(
t
)
single_template_embeds
.
update
({
"pair"
:
t
})
if
(
inplace_safe
):
t_pair
[...,
i
,
:,
:,
:]
=
t
else
:
pair_embeds
.
append
(
t
)
del
t
template_embeds
.
append
(
single_template_embeds
)
if
(
not
inplace_safe
):
t_pair
=
torch
.
cat
(
pair_embeds
,
dim
=
templ_dim
)
template_embeds
=
dict_multimap
(
partial
(
torch
.
cat
,
dim
=
templ_dim
),
del
pair_embeds
template_embeds
,
)
# [*, S_t, N, N, C_z]
# [*, S_t, N, N, C_z]
t
=
self
.
template_pair_stack
(
t
=
self
.
template_pair_stack
(
t
emplate_embeds
[
"
pair
"
]
,
t
_
pair
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
del
t_pair
# [*, N, N, C_z]
# [*, N, N, C_z]
t
=
self
.
template_pointwise_att
(
t
=
self
.
template_pointwise_att
(
...
@@ -164,17 +180,28 @@ class AlphaFold(nn.Module):
...
@@ -164,17 +180,28 @@ class AlphaFold(nn.Module):
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
)
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
])
>
0
)
if
(
inplace_safe
):
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
ret
=
{}
if
self
.
config
.
template
.
embed_angles
:
if
self
.
config
.
template
.
embed_angles
:
ret
[
"template_angle_embedding"
]
=
template_embeds
[
"angle"
]
template_angle_feat
=
build_template_angle_feat
(
batch
)
# [*, S_t, N, C_m]
a
=
self
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
return
ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_
prev
,
_recycle
=
True
):
def
iteration
(
self
,
feats
,
prev
s
,
_recycle
=
True
):
# Primary output dictionary
# Primary output dictionary
outputs
=
{}
outputs
=
{}
...
@@ -190,13 +217,14 @@ class AlphaFold(nn.Module):
...
@@ -190,13 +217,14 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
device
=
feats
[
"target_feat"
].
device
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
# Prep some features
seq_mask
=
feats
[
"seq_mask"
]
seq_mask
=
feats
[
"seq_mask"
]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
pair_mask
=
seq_mask
[...,
None
]
*
seq_mask
[...,
None
,
:]
msa_mask
=
feats
[
"msa_mask"
]
msa_mask
=
feats
[
"msa_mask"
]
# Initialize the MSA and pair representations
#
#
Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
# z: [*, N, N, C_z]
...
@@ -206,7 +234,11 @@ class AlphaFold(nn.Module):
...
@@ -206,7 +234,11 @@ class AlphaFold(nn.Module):
feats
[
"msa_feat"
],
feats
[
"msa_feat"
],
)
)
# Initialize the recycling embeddings, if needs be
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function.
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
# [*, N, C_m]
# [*, N, C_m]
m_1_prev
=
m
.
new_zeros
(
m_1_prev
=
m
.
new_zeros
(
...
@@ -236,24 +268,16 @@ class AlphaFold(nn.Module):
...
@@ -236,24 +268,16 @@ class AlphaFold(nn.Module):
m_1_prev
,
m_1_prev
,
z_prev
,
z_prev
,
x_prev
,
x_prev
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
# EDIT: This has since been removed from the official codebase (2cd61a)
# if(not _recycle):
# m_1_prev_emb *= 0
# z_prev_emb *= 0
# [*, S_c, N, C_m]
# [*, S_c, N, C_m]
m
[...,
0
,
:,
:]
+=
m_1_prev_emb
m
[...,
0
,
:,
:]
+=
m_1_prev_emb
# [*, N, N, C_z]
# [*, N, N, C_z]
z
+=
z_prev_emb
z
+=
z_prev_emb
#
Possibly prevents memory fragmentation
#
This matters during inference with large N
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
...
@@ -269,7 +293,10 @@ class AlphaFold(nn.Module):
...
@@ -269,7 +293,10 @@ class AlphaFold(nn.Module):
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
z
=
add
(
z
,
template_embeds
.
pop
(
"template_pair_embedding"
),
inplace_safe
,
)
if
self
.
config
.
template
.
embed_angles
:
if
self
.
config
.
template
.
embed_angles
:
# [*, S = S_c + S_t, N, C_m]
# [*, S = S_c + S_t, N, C_m]
...
@@ -289,7 +316,7 @@ class AlphaFold(nn.Module):
...
@@ -289,7 +316,7 @@ class AlphaFold(nn.Module):
if
self
.
config
.
extra_msa
.
enabled
:
if
self
.
config
.
extra_msa
.
enabled
:
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
z
=
self
.
extra_msa_stack
(
a
,
a
,
...
@@ -301,6 +328,8 @@ class AlphaFold(nn.Module):
...
@@ -301,6 +328,8 @@ class AlphaFold(nn.Module):
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
del
a
# Run MSA + pair embeddings through the trunk of the network
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# z: [*, N, N, C_z]
...
@@ -416,6 +445,7 @@ class AlphaFold(nn.Module):
...
@@ -416,6 +445,7 @@ class AlphaFold(nn.Module):
"""
"""
# Initialize recycling embeddings
# Initialize recycling embeddings
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
m_1_prev
,
z_prev
,
x_prev
=
None
,
None
,
None
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
# Disable activation checkpointing for the first few recycling iters
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled
=
torch
.
is_grad_enabled
()
is_grad_enabled
=
torch
.
is_grad_enabled
()
...
@@ -440,12 +470,15 @@ class AlphaFold(nn.Module):
...
@@ -440,12 +470,15 @@ class AlphaFold(nn.Module):
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
feats
,
feats
,
m_1_prev
,
prevs
,
z_prev
,
x_prev
,
_recycle
=
(
num_iters
>
1
)
_recycle
=
(
num_iters
>
1
)
)
)
if
(
not
is_final_iter
):
del
outputs
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
del
m_1_prev
,
z_prev
,
x_prev
# Run auxiliary heads
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/outer_product_mean.py
View file @
6e66b218
...
@@ -82,7 +82,13 @@ class OuterProductMean(nn.Module):
...
@@ -82,7 +82,13 @@ class OuterProductMean(nn.Module):
no_batch_dims
=
1
,
no_batch_dims
=
1
,
)
)
out
.
append
(
outer
)
out
.
append
(
outer
)
outer
=
torch
.
stack
(
out
,
dim
=
0
)
# For some cursed reason making this distinction saves memory
if
(
len
(
out
)
==
1
):
outer
=
out
[
0
].
unsqueeze
(
0
)
else
:
outer
=
torch
.
stack
(
out
,
dim
=
0
)
outer
=
outer
.
reshape
(
a
.
shape
[:
-
3
]
+
outer
.
shape
[
1
:])
outer
=
outer
.
reshape
(
a
.
shape
[:
-
3
]
+
outer
.
shape
[
1
:])
return
outer
return
outer
...
@@ -90,7 +96,8 @@ class OuterProductMean(nn.Module):
...
@@ -90,7 +96,8 @@ class OuterProductMean(nn.Module):
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
chunk_size
:
Optional
[
int
]
=
None
,
_inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -105,12 +112,17 @@ class OuterProductMean(nn.Module):
...
@@ -105,12 +112,17 @@ class OuterProductMean(nn.Module):
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
mask
=
m
.
new_ones
(
m
.
shape
[:
-
1
])
# [*, N_seq, N_res, C_m]
# [*, N_seq, N_res, C_m]
m
=
self
.
layer_norm
(
m
)
ln
=
self
.
layer_norm
(
m
)
# [*, N_seq, N_res, C]
# [*, N_seq, N_res, C]
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
a
=
self
.
linear_1
(
m
)
*
mask
a
=
self
.
linear_1
(
ln
)
b
=
self
.
linear_2
(
m
)
*
mask
a
=
a
*
mask
b
=
self
.
linear_2
(
ln
)
b
=
b
*
mask
del
ln
a
=
a
.
transpose
(
-
2
,
-
3
)
a
=
a
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
b
=
b
.
transpose
(
-
2
,
-
3
)
...
@@ -122,8 +134,12 @@ class OuterProductMean(nn.Module):
...
@@ -122,8 +134,12 @@ class OuterProductMean(nn.Module):
# [*, N_res, N_res, 1]
# [*, N_res, N_res, 1]
norm
=
torch
.
einsum
(
"...abc,...adc->...bdc"
,
mask
,
mask
)
norm
=
torch
.
einsum
(
"...abc,...adc->...bdc"
,
mask
,
mask
)
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
outer
=
outer
/
(
self
.
eps
+
norm
)
if
(
_inplace
):
outer
/=
norm
else
:
outer
=
outer
/
norm
return
outer
return
outer
openfold/model/template.py
View file @
6e66b218
...
@@ -34,10 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
...
@@ -34,10 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
TriangleMultiplicationIncoming
,
)
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.checkpointing
import
checkpoint_blocks
from
openfold.utils.feats
import
(
build_template_angle_feat
,
build_template_pair_feat
,
)
from
openfold.utils.tensor_utils
import
(
from
openfold.utils.tensor_utils
import
(
add
,
chunk_layer
,
chunk_layer
,
permute_final_dims
,
permute_final_dims
,
flatten_final_dims
,
flatten_final_dims
,
tensor_tree_map
,
)
)
...
@@ -191,7 +197,8 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -191,7 +197,8 @@ class TemplatePairStackBlock(nn.Module):
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
):
):
single_templates
=
[
single_templates
=
[
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
t
.
unsqueeze
(
-
4
)
for
t
in
torch
.
unbind
(
z
,
dim
=-
4
)
...
@@ -202,44 +209,71 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -202,44 +209,71 @@ class TemplatePairStackBlock(nn.Module):
for
i
in
range
(
len
(
single_templates
)):
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single_mask
=
single_templates_masks
[
i
]
single
=
single
+
self
.
dropout_row
(
single
=
add
(
single
,
self
.
tri_att_start
(
self
.
dropout_row
(
single
,
self
.
tri_att_start
(
chunk_size
=
chunk_size
,
single
,
mask
=
single_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
mask
=
single_mask
,
)
use_lma
=
use_lma
,
)
),
_inplace
,
)
)
single
=
single
+
self
.
dropout_col
(
self
.
tri_att_end
(
single
=
add
(
single
,
single
,
self
.
dropout_col
(
chunk_size
=
chunk_size
,
self
.
tri_att_end
(
mask
=
single_mask
,
single
,
use_lma
=
use_lma
,
chunk_size
=
chunk_size
,
)
mask
=
single_mask
,
)
use_lma
=
use_lma
,
single
=
single
+
self
.
dropout_row
(
)
self
.
tri_mul_out
(
),
single
,
_inplace
,
mask
=
single_mask
)
)
)
single
=
single
+
self
.
dropout_row
(
self
.
tri_mul_in
(
tmu_update
=
self
.
tri_mul_out
(
single
,
single
,
mask
=
single_mask
mask
=
single_mask
,
)
_inplace
=
_inplace
,
_add_with_inplace
=
True
,
)
)
single
=
single
+
self
.
pair_transition
(
if
(
not
_inplace
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
tmu_update
=
self
.
tri_mul_in
(
single
,
single
,
mask
=
single_mask
if
_mask_trans
else
None
,
mask
=
single_mask
,
chunk_size
=
chunk_size
,
_inplace
=
_inplace
,
_add_with_inplace
=
True
,
)
if
(
not
_inplace
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
del
tmu_update
single
=
add
(
single
,
self
.
pair_transition
(
single
,
mask
=
single_mask
if
_mask_trans
else
None
,
chunk_size
=
chunk_size
,
),
_inplace
,
)
)
single_templates
[
i
]
=
single
if
(
not
_inplace
):
single_templates
[
i
]
=
single
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
if
(
not
_inplace
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
return
z
...
@@ -328,6 +362,7 @@ class TemplatePairStack(nn.Module):
...
@@ -328,6 +362,7 @@ class TemplatePairStack(nn.Module):
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
)
for
b
in
self
.
blocks
for
b
in
self
.
blocks
],
],
...
@@ -338,3 +373,223 @@ class TemplatePairStack(nn.Module):
...
@@ -338,3 +373,223 @@ class TemplatePairStack(nn.Module):
t
=
self
.
layer_norm
(
t
)
t
=
self
.
layer_norm
(
t
)
return
t
return
t
def
embed_templates_offload
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
template_chunk_size
=
256
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
template_chunk_size:
Integer value controlling how quickly the offloaded pair embedding
tensor is brought back into GPU memory. In dire straits, can be
lowered to reduce memory consumption of this function even more.
Returns:
A dictionary of template pair and angle embeddings.
A version of the "embed_templates" method of the AlphaFold class that
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
for
i
in
range
(
n_templ
):
idx
=
batch
[
"template_aatype"
].
new_tensor
(
i
)
single_template_feats
=
tensor_tree_map
(
lambda
t
:
torch
.
index_select
(
t
,
templ_dim
,
idx
),
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
single_template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
t
=
model
.
template_pair_embedder
(
t
)
# [*, 1, N, N, C_z]
t
=
model
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
pair_embeds_cpu
.
append
(
t
.
cpu
())
del
t
# Preallocate the output tensor
t
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n
,
template_chunk_size
):
pair_chunks
=
[
p
[...,
i
:
i
+
template_chunk_size
,
:,
:]
for
p
in
pair_embeds_cpu
]
pair_chunk
=
torch
.
cat
(
pair_chunks
,
dim
=
templ_dim
).
to
(
device
=
z
.
device
)
z_chunk
=
z
[...,
i
:
i
+
template_chunk_size
,
:,
:]
att_chunk
=
model
.
template_pointwise_att
(
pair_chunk
,
z_chunk
,
template_mask
=
batch
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
t
[...,
i
:
i
+
template_chunk_size
,
:,
:]
=
att_chunk
del
pair_chunks
if
(
inplace_safe
):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
t
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
return
ret
def
embed_templates_average
(
model
,
batch
,
z
,
pair_mask
,
templ_dim
,
templ_group_size
=
2
,
):
"""
Args:
model:
An AlphaFold model object
batch:
An AlphaFold input batch. See documentation of AlphaFold.
z:
A [*, N, N, C_z] pair embedding
pair_mask:
A [*, N, N] pair mask
templ_dim:
The template dimension of the template tensors in batch
templ_group_size:
Granularity of the approximation. Larger values trade memory for
greater proximity to the original function
Returns:
A dictionary of template pair and angle embeddings.
A memory-efficient approximation of the "embed_templates" method of the
AlphaFold class. Instead of running pointwise attention over pair
embeddings for all of the templates at the same time, it splits templates
into groups of size templ_group_size, computes embeddings for each group
normally, and then averages the group embeddings. In our experiments, this
approximation has a minimal effect on the quality of the resulting
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
out_tensor
=
z
.
new_zeros
(
z
.
shape
)
for
i
in
range
(
0
,
n_templ
,
templ_group_size
):
def
slice_template_tensor
(
t
):
s
=
[
slice
(
None
)
for
_
in
t
.
shape
]
s
[
templ_dim
]
=
slice
(
i
,
i
+
templ_group_size
)
return
t
[
s
]
template_feats
=
tensor_tree_map
(
slice_template_tensor
,
batch
,
)
# [*, N, N, C_t]
t
=
build_template_pair_feat
(
template_feats
,
use_unit_vector
=
model
.
config
.
template
.
use_unit_vector
,
inf
=
model
.
config
.
template
.
inf
,
eps
=
model
.
config
.
template
.
eps
,
**
model
.
config
.
template
.
distogram
,
).
to
(
z
.
dtype
)
# [*, S_t, N, N, C_z]
t
=
model
.
template_pair_embedder
(
t
)
t
=
model
.
template_pair_stack
(
t
,
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
model
.
globals
.
chunk_size
,
use_lma
=
model
.
globals
.
use_lma
,
_mask_trans
=
model
.
config
.
_mask_trans
,
)
t
=
model
.
template_pointwise_att
(
t
,
z
,
template_mask
=
template_feats
[
"template_mask"
].
to
(
dtype
=
z
.
dtype
),
use_lma
=
model
.
globals
.
use_lma
,
)
denom
=
math
.
ceil
(
n_templ
/
templ_group_size
)
if
(
inplace_safe
):
t
/=
denom
else
:
t
=
t
/
denom
if
(
inplace_safe
):
out_tensor
+=
t
else
:
out_tensor
=
out_tensor
+
t
del
t
if
(
inplace_safe
):
out_tensor
*=
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
else
:
out_tensor
=
out_tensor
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
if
model
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
,
)
# [*, N, C_m]
a
=
model
.
template_angle_embedder
(
template_angle_feat
)
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
out_tensor
})
return
ret
openfold/model/triangular_multiplicative_update.py
View file @
6e66b218
...
@@ -20,7 +20,7 @@ import torch
...
@@ -20,7 +20,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
permute_final_dims
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
,
permute_final_dims
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
class
TriangleMultiplicativeUpdate
(
nn
.
Module
):
...
@@ -55,12 +55,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -55,12 +55,310 @@ class TriangleMultiplicativeUpdate(nn.Module):
def
_combine_projections
(
self
,
def
_combine_projections
(
self
,
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
_inplace_chunk_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
(
"This method needs to be overridden"
)
if
(
self
.
_outgoing
):
a
=
permute_final_dims
(
a
,
(
2
,
0
,
1
))
b
=
permute_final_dims
(
b
,
(
2
,
1
,
0
))
else
:
a
=
permute_final_dims
(
a
,
(
2
,
1
,
0
))
b
=
permute_final_dims
(
b
,
(
2
,
0
,
1
))
if
(
_inplace_chunk_size
is
not
None
):
# To be replaced by torch vmap
for
i
in
range
(
0
,
a
.
shape
[
-
3
],
_inplace_chunk_size
):
a_chunk
=
a
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
b_chunk
=
b
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
a
[...,
i
:
i
+
_inplace_chunk_size
,
:,
:]
=
(
torch
.
matmul
(
a_chunk
,
b_chunk
,
)
)
p
=
a
else
:
p
=
torch
.
matmul
(
a
,
b
)
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
def
_inference_forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_chunk_size
:
Optional
[
int
]
=
None
,
with_add
:
bool
=
True
,
):
"""
Args:
z:
A [*, N, N, C_z] pair representation
mask:
A [*, N, N] pair mask
inplace_chunk_size:
Size of chunks used in the main computation. Increase to trade
memory for speed.
with_add:
If True, z is overwritten with (z + update). Otherwise, it is
overwritten with (update).
Returns:
A reference to the overwritten z
More memory-efficient, inference-only version of the forward function.
Uses in-place operations, fusion of the addition that happens after
this module in the Evoformer, a smidge of recomputation, and
a cache of overwritten values to lower peak memory consumption of this
module from 5x the size of the input tensor z to 2.5x its size. Useful
for inference on extremely long sequences.
It works as follows. We will make reference to variables used in the
default forward implementation below. Naively, triangle multiplication
attention requires the manifestation of 5 tensors the size of z:
1) z, the "square" input tensor, 2) a, the first projection of z,
3) b, the second projection of b, 4) g, a z-sized mask, and 5) a
z-sized tensor for intermediate computations. For large N, this is
prohibitively expensive; for N=4000, for example, z is more than 8GB
alone. To avoid this problem, we compute b, g, and all intermediate
tensors in small chunks, noting that the chunks required to compute a
chunk of the output depend only on the tensor a and corresponding
vertical and horizontal chunks of z. This suggests an algorithm that
loops over pairs of chunks of z: hereafter "columns" and "rows" of
z, even though each "column" and "row" in fact contains
inplace_chunk_size contiguous true columns and rows of z. Writing
output chunks to a new tensor would bring total memory consumption
down to 3x the size of z. However, more memory can be saved by writing
output chunks directly to z in-place. WLOG, we choose to write output
chunks vertically, overwriting the ith "column" of z at the end of
the ith iteration of the main loop. Despite this overwriting, the
ith column is always one column ahead of previously overwritten columns
and can be recovered directly from z. After the first iteration,
however, the ith row of z is always at least partially overwritten. For
this reason, we introduce the z-cache, a tensor one-half the size of
z. The z-cache initially contains the left half (2nd and 3rd quadrants)
of z. For 0 < i < N/2, the missing left part of the ith row of z is
recovered from this cache at the beginning of the ith iteration. Once i
exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th
quadrants of z instead. Though the 3rd quadrant of the original z is
entirely overwritten at this point, it can be recovered from the z-cache
itself. Thereafter, the ith row of z can be recovered in its entirety
from the reoriented z-cache. After the final iteration, z has been
completely overwritten and contains the triangular multiplicative
update. If with_add is True, it instead contains the sum of z and the
triangular multiplicative update. In either case, peak memory
consumption is just 2.5x the size of z, disregarding memory used for
chunks and other small variables.
"""
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
def
compute_projection_helper
(
pair
,
mask
,
a
=
True
):
if
(
a
):
linear_g
=
self
.
linear_a_g
linear_p
=
self
.
linear_a_p
else
:
linear_g
=
self
.
linear_b_g
linear_p
=
self
.
linear_b_p
pair
=
self
.
layer_norm_in
(
pair
)
p
=
linear_g
(
pair
)
p
.
sigmoid_
()
p
*=
linear_p
(
pair
)
p
*=
mask
p
=
permute_final_dims
(
p
,
(
2
,
0
,
1
))
return
p
def
compute_projection
(
pair
,
mask
,
a
=
True
,
chunked
=
True
):
need_transpose
=
self
.
_outgoing
^
a
if
(
not
chunked
):
p
=
compute_projection_helper
(
pair
,
mask
,
a
)
if
(
need_transpose
):
p
=
p
.
transpose
(
-
1
,
-
2
)
else
:
# This computation is chunked so as not to exceed our 2.5x
# budget with a large intermediate tensor
linear_g
=
self
.
linear_a_g
if
a
else
self
.
linear_b_g
c
=
linear_g
.
bias
.
shape
[
-
1
]
out_shape
=
pair
.
shape
[:
-
3
]
+
(
c
,)
+
pair
.
shape
[
-
3
:
-
1
]
p
=
pair
.
new_zeros
(
out_shape
)
for
i
in
range
(
0
,
pair
.
shape
[
-
3
],
inplace_chunk_size
):
pair_chunk
=
pair
[...,
i
:
i
+
inplace_chunk_size
,
:,
:]
mask_chunk
=
mask
[...,
i
:
i
+
inplace_chunk_size
,
:,
:]
pair_chunk
=
compute_projection_helper
(
pair
[...,
i
:
i
+
inplace_chunk_size
,
:,
:],
mask
[...,
i
:
i
+
inplace_chunk_size
,
:,
:],
a
,
)
if
(
need_transpose
):
pair_chunk
=
pair_chunk
.
transpose
(
-
1
,
-
2
)
p
[...,
i
:
i
+
inplace_chunk_size
]
=
pair_chunk
else
:
p
[...,
i
:
i
+
inplace_chunk_size
,
:]
=
pair_chunk
del
pair_chunk
return
p
# We start by fully manifesting a. In addition to the input, this
# brings total memory consumption to 2x z (disregarding size of chunks)
# [*, N, N, c]
a
=
compute_projection
(
z
,
mask
,
True
,
chunked
=
True
)
if
(
inplace_chunk_size
is
not
None
):
n
=
a
.
shape
[
-
1
]
half_n
=
n
//
2
+
n
%
2
row_dim
=
-
3
col_dim
=
-
2
b_chunk_dim
=
row_dim
if
self
.
_outgoing
else
col_dim
def
empty_slicer
(
t
):
return
[
slice
(
None
)
for
_
in
t
.
shape
]
def
slice_tensor
(
t
,
start
,
end
,
dim
):
# Slices start:end from the dim dimension of t
s
=
empty_slicer
(
t
)
s
[
dim
]
=
slice
(
start
,
end
)
return
t
[
s
]
def
flip_z_cache_
(
z_cache
,
z
):
# "Reorient" the z_cache (see below), filling it with quadrants
# 3---recovered from the z_cache---and 4---recovered from z---
# of the input tensor z.
quadrant_3
=
slice_tensor
(
z_cache
,
half_n
,
None
,
row_dim
)
z_cache
=
z_cache
.
transpose
(
row_dim
,
col_dim
)
# If n is odd, we need to shrink the z_cache by one row
z_cache
=
z_cache
[...,
:(
n
//
2
),
:,
:]
# Move the 3rd quadrant of z into the
first_half_slicer
=
empty_slicer
(
z_cache
)
first_half_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_cache
[
first_half_slicer
]
=
quadrant_3
# Get the fourth quadrant of z
quadrant_4
=
slice_tensor
(
z
,
half_n
,
None
,
row_dim
)
quadrant_4
=
slice_tensor
(
quadrant_4
,
half_n
,
None
,
col_dim
)
# Insert said quadrant into the rotated z-cache
quadrant_3_slicer
=
empty_slicer
(
z_cache
)
quadrant_3_slicer
[
col_dim
]
=
slice
(
half_n
,
None
)
z_cache
[
quadrant_3_slicer
]
=
quadrant_4
return
z_cache
# Initialize the z cache to the left half of z.
z_cache_shape
=
list
(
z
.
shape
)
z_cache_shape
[
col_dim
]
=
half_n
z_cache
=
z
.
new_zeros
(
z_cache_shape
)
z_cache_slicer
=
empty_slicer
(
z_cache
)
z_cache_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_cache
.
copy_
(
z
[
z_cache_slicer
])
z_cache_rotated
=
False
# We need to reorient the z-cache at the halfway point, and we
# don't want a single chunk to straddle that point. We contract one
# of the chunks in the middle to address that problem.
i_range
=
list
(
range
(
0
,
half_n
,
inplace_chunk_size
))
initial_offsets
=
[
i_2
-
i_1
for
i_1
,
i_2
in
zip
(
i_range
,
i_range
[
1
:]
+
[
half_n
])
]
after_half
=
list
(
range
(
half_n
,
n
,
inplace_chunk_size
))
after_half_offsets
=
[
inplace_chunk_size
for
_
in
after_half
]
combined_range_with_offsets
=
zip
(
i_range
+
after_half
,
initial_offsets
+
after_half_offsets
)
for
i
,
offset
in
combined_range_with_offsets
:
if
(
not
z_cache_rotated
and
i
>=
half_n
):
z_cache
=
flip_z_cache_
(
z_cache
,
z
)
z_cache_rotated
=
True
z_chunk_b
=
slice_tensor
(
z
,
i
,
i
+
offset
,
b_chunk_dim
,
)
mask_chunk
=
slice_tensor
(
mask
,
i
,
i
+
offset
,
b_chunk_dim
,
)
z_chunk_b
=
z_chunk_b
.
clone
()
if
(
b_chunk_dim
==
col_dim
):
z_chunk_b
=
slice_tensor
(
z
,
i
,
i
+
offset
,
col_dim
)
else
:
# b_chunk_dim == row_dim
# In this case, the b-dimension (b_chunk_dim) is partially
# overwritten at the end of each iteration. We need to
# restore the missing component from the z-cache.
if
(
not
z_cache_rotated
):
z_chunk_slicer
=
empty_slicer
(
z_chunk_b
)
z_chunk_slicer
[
col_dim
]
=
slice
(
0
,
half_n
)
z_chunk_b
[
z_chunk_slicer
]
=
slice_tensor
(
z_cache
,
i
,
i
+
offset
,
row_dim
,
)
else
:
z_cache_offset
=
i
-
half_n
z_chunk_b
=
slice_tensor
(
z_cache
,
z_cache_offset
,
z_cache_offset
+
offset
,
row_dim
)
b_chunk
=
compute_projection
(
z_chunk_b
,
mask_chunk
,
a
=
False
,
chunked
=
False
)
del
z_chunk_b
x_chunk
=
torch
.
matmul
(
a
,
b_chunk
,
)
x_chunk
=
permute_final_dims
(
x_chunk
,
(
1
,
2
,
0
))
x_chunk
=
self
.
layer_norm_out
(
x_chunk
)
x_chunk
=
self
.
linear_z
(
x_chunk
)
# The g dimension (col_dim) is parallel to and ahead of the
# overwrites in z. We can extract the g chunk normally.
z_chunk_g
=
slice_tensor
(
z
,
i
,
i
+
offset
,
col_dim
)
g_chunk
=
self
.
linear_g
(
self
.
layer_norm_in
(
z_chunk_g
))
g_chunk
.
sigmoid_
()
del
z_chunk_g
x_chunk
*=
g_chunk
# Write the columns into z in-place
z_slicer
=
empty_slicer
(
z
)
z_slicer
[
col_dim
]
=
slice
(
i
,
i
+
offset
)
if
(
with_add
):
z
[
z_slicer
]
+=
x_chunk
else
:
z
[
z_slicer
]
=
x_chunk
else
:
b
=
compute_projection
(
z
,
mask
,
False
,
False
)
x
=
torch
.
matmul
(
a
,
b
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
linear_g
(
z
)
g
.
sigmoid_
()
x
*=
g
if
(
with_add
):
z
+=
x
else
:
z
=
x
return
z
def
forward
(
self
,
def
forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_inplace
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -71,57 +369,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -71,57 +369,46 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
Returns:
[*, N_res, N_res, C_z] output tensor
[*, N_res, N_res, C_z] output tensor
"""
"""
if
(
_inplace
):
x
=
self
.
_inference_forward
(
z
,
mask
,
inplace_chunk_size
=
_inplace_chunk_size
,
with_add
=
_add_with_inplace
,
)
return
x
if
mask
is
None
:
if
mask
is
None
:
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
z
.
new_ones
(
z
.
shape
[:
-
1
])
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
z
=
self
.
layer_norm_in
(
z
)
z
=
self
.
layer_norm_in
(
z
)
a
=
self
.
linear_a_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
a
=
mask
a
=
a
*
mask
a
=
a
*
self
.
sigmoid
(
self
.
linear_a_g
(
z
))
b
=
self
.
linear_b_p
(
z
)
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
a
=
a
*
self
.
linear_a_p
(
z
)
b
=
b
*
mask
b
=
mask
b
=
b
*
self
.
sigmoid
(
self
.
linear_b_g
(
z
))
b
=
b
*
self
.
linear_b_p
(
z
)
x
=
self
.
_combine_projections
(
a
,
b
)
x
=
self
.
_combine_projections
(
a
,
b
)
del
a
,
b
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
g
=
self
.
sigmoid
(
self
.
linear_g
(
z
))
z
=
x
*
g
x
=
x
*
g
return
z
return
x
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationOutgoing
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 11.
Implements Algorithm 11.
"""
"""
def
_combine_projections
(
self
,
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
True
)
a
:
torch
.
Tensor
,
# [*, N_i, N_k, C]
b
:
torch
.
Tensor
,
# [*, N_j, N_k, C]
):
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
(
2
,
0
,
1
)),
permute_final_dims
(
b
,
(
2
,
1
,
0
)),
)
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
class
TriangleMultiplicationIncoming
(
TriangleMultiplicativeUpdate
):
"""
"""
Implements Algorithm 12.
Implements Algorithm 12.
"""
"""
def
_combine_projections
(
self
,
__init__
=
partialmethod
(
TriangleMultiplicativeUpdate
.
__init__
,
_outgoing
=
False
)
a
:
torch
.
Tensor
,
# [*, N_k, N_i, C]
b
:
torch
.
Tensor
,
# [*, N_k, N_j, C]
):
# [*, C, N_i, N_j]
p
=
torch
.
matmul
(
permute_final_dims
(
a
,
(
2
,
1
,
0
)),
permute_final_dims
(
b
,
(
2
,
0
,
1
)),
)
# [*, N_i, N_j, C]
return
permute_final_dims
(
p
,
(
1
,
2
,
0
))
openfold/np/protein.py
View file @
6e66b218
...
@@ -140,12 +140,20 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
...
@@ -140,12 +140,20 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
residue_index
.
append
(
res
.
id
[
1
])
residue_index
.
append
(
res
.
id
[
1
])
b_factors
.
append
(
res_b_factors
)
b_factors
.
append
(
res_b_factors
)
parents
=
None
if
(
"PARENT"
in
pdb_str
):
for
l
in
pdb_str
.
split
(
"
\n
"
):
if
(
"PARENT"
in
l
and
not
"N/A"
in
l
):
parents
=
l
.
split
()[
1
:]
break
return
Protein
(
return
Protein
(
atom_positions
=
np
.
array
(
atom_positions
),
atom_positions
=
np
.
array
(
atom_positions
),
atom_mask
=
np
.
array
(
atom_mask
),
atom_mask
=
np
.
array
(
atom_mask
),
aatype
=
np
.
array
(
aatype
),
aatype
=
np
.
array
(
aatype
),
residue_index
=
np
.
array
(
residue_index
),
residue_index
=
np
.
array
(
residue_index
),
b_factors
=
np
.
array
(
b_factors
),
b_factors
=
np
.
array
(
b_factors
),
parents
=
parents
,
)
)
...
...
openfold/np/relax/amber_minimize.py
View file @
6e66b218
...
@@ -516,6 +516,9 @@ def run_pipeline(
...
@@ -516,6 +516,9 @@ def run_pipeline(
_check_residues_are_well_defined
(
prot
)
_check_residues_are_well_defined
(
prot
)
pdb_string
=
clean_protein
(
prot
,
checks
=
checks
)
pdb_string
=
clean_protein
(
prot
,
checks
=
checks
)
# We keep the input around to restore metadata deleted by the relaxer
input_prot
=
prot
exclude_residues
=
exclude_residues
or
[]
exclude_residues
=
exclude_residues
or
[]
exclude_residues
=
set
(
exclude_residues
)
exclude_residues
=
set
(
exclude_residues
)
violations
=
np
.
inf
violations
=
np
.
inf
...
@@ -532,6 +535,11 @@ def run_pipeline(
...
@@ -532,6 +535,11 @@ def run_pipeline(
max_attempts
=
max_attempts
,
max_attempts
=
max_attempts
,
use_gpu
=
use_gpu
,
use_gpu
=
use_gpu
,
)
)
headers
=
protein
.
get_pdb_headers
(
prot
)
if
(
len
(
headers
)
>
0
):
ret
[
"min_pdb"
]
=
'
\n
'
.
join
([
'
\n
'
.
join
(
headers
),
ret
[
"min_pdb"
]])
prot
=
protein
.
from_pdb_string
(
ret
[
"min_pdb"
])
prot
=
protein
.
from_pdb_string
(
ret
[
"min_pdb"
])
if
place_hydrogens_every_iteration
:
if
place_hydrogens_every_iteration
:
pdb_string
=
clean_protein
(
prot
,
checks
=
True
)
pdb_string
=
clean_protein
(
prot
,
checks
=
True
)
...
...
openfold/utils/exponential_moving_average.py
View file @
6e66b218
...
@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
...
@@ -58,7 +58,8 @@ class ExponentialMovingAverage:
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
self
.
params
=
state_dict
[
"params"
]
for
k
in
state_dict
[
"params"
].
keys
():
self
.
params
[
k
]
=
state_dict
[
"params"
][
k
].
clone
()
self
.
decay
=
state_dict
[
"decay"
]
self
.
decay
=
state_dict
[
"decay"
]
def
state_dict
(
self
)
->
OrderedDict
:
def
state_dict
(
self
)
->
OrderedDict
:
...
...
openfold/utils/loss.py
View file @
6e66b218
...
@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
...
@@ -43,9 +43,15 @@ def softmax_cross_entropy(logits, labels):
def
sigmoid_cross_entropy
(
logits
,
labels
):
def
sigmoid_cross_entropy
(
logits
,
labels
):
log_p
=
torch
.
log
(
torch
.
sigmoid
(
logits
))
logits_dtype
=
logits
.
dtype
log_not_p
=
torch
.
log
(
torch
.
sigmoid
(
-
logits
))
logits
=
logits
.
double
()
loss
=
-
labels
*
log_p
-
(
1
-
labels
)
*
log_not_p
labels
=
labels
.
double
()
log_p
=
torch
.
nn
.
functional
.
logsigmoid
(
logits
)
# log_p = torch.log(torch.sigmoid(logits))
log_not_p
=
torch
.
nn
.
functional
.
logsigmoid
(
-
1
*
logits
)
# log_not_p = torch.log(torch.sigmoid(-logits))
loss
=
(
-
1.
*
labels
)
*
log_p
-
(
1.
-
labels
)
*
log_not_p
loss
=
loss
.
to
(
dtype
=
logits_dtype
)
return
loss
return
loss
...
@@ -1472,13 +1478,13 @@ def experimentally_resolved_loss(
...
@@ -1472,13 +1478,13 @@ def experimentally_resolved_loss(
loss
=
torch
.
sum
(
errors
*
atom37_atom_exists
,
dim
=-
1
)
loss
=
torch
.
sum
(
errors
*
atom37_atom_exists
,
dim
=-
1
)
loss
=
loss
/
(
eps
+
torch
.
sum
(
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
)))
loss
=
loss
/
(
eps
+
torch
.
sum
(
atom37_atom_exists
,
dim
=
(
-
1
,
-
2
)))
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
torch
.
sum
(
loss
,
dim
=-
1
)
loss
=
loss
*
(
loss
=
loss
*
(
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
(
resolution
>=
min_resolution
)
&
(
resolution
<=
max_resolution
)
)
)
loss
=
torch
.
mean
(
loss
)
loss
=
torch
.
mean
(
loss
)
return
loss
return
loss
...
...
openfold/utils/tensor_utils.py
View file @
6e66b218
...
@@ -19,6 +19,17 @@ import torch.nn as nn
...
@@ -19,6 +19,17 @@ import torch.nn as nn
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
from
typing
import
Tuple
,
List
,
Callable
,
Any
,
Dict
,
Sequence
,
Optional
def
add
(
m1
,
m2
,
inplace
):
# The first operation in a checkpoint can't be in-place, but it's
# nice to have in-place addition during inference. Thus...
if
(
not
inplace
):
m1
=
m1
+
m2
else
:
m1
+=
m2
return
m1
def
permute_final_dims
(
tensor
:
torch
.
Tensor
,
inds
:
List
[
int
]):
def
permute_final_dims
(
tensor
:
torch
.
Tensor
,
inds
:
List
[
int
]):
zero_index
=
-
1
*
len
(
inds
)
zero_index
=
-
1
*
len
(
inds
)
first_inds
=
list
(
range
(
len
(
tensor
.
shape
[:
zero_index
])))
first_inds
=
list
(
range
(
len
(
tensor
.
shape
[:
zero_index
])))
...
...
run_pretrained_openfold.py
View file @
6e66b218
...
@@ -110,7 +110,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
...
@@ -110,7 +110,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
# Prep protein metadata
# Prep protein metadata
template_domain_names
=
[]
template_domain_names
=
[]
template_chain_index
=
None
template_chain_index
=
None
if
(
feature_processor
.
config
.
common
.
use_templates
):
if
(
feature_processor
.
config
.
common
.
use_templates
and
"template_domain_names"
in
feature_dict
):
template_domain_names
=
[
template_domain_names
=
[
t
.
decode
(
"utf-8"
)
for
t
in
feature_dict
[
"template_domain_names"
]
t
.
decode
(
"utf-8"
)
for
t
in
feature_dict
[
"template_domain_names"
]
]
]
...
...
scripts/precompute_alignments.py
View file @
6e66b218
...
@@ -227,7 +227,7 @@ if __name__ == "__main__":
...
@@ -227,7 +227,7 @@ if __name__ == "__main__":
)
)
add_data_args
(
parser
)
add_data_args
(
parser
)
parser
.
add_argument
(
parser
.
add_argument
(
"--raise_errors"
,
type
=
bool
,
default
=
False
,
"--raise_errors"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to crash on parsing errors"
help
=
"Whether to crash on parsing errors"
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
tests/test_triangular_multiplicative_update.py
View file @
6e66b218
...
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -30,12 +30,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_shape
(
self
):
def
test_shape
(
self
):
c_z
=
consts
.
c_z
c_z
=
consts
.
c_z
c
=
11
c
=
11
outgoing
=
True
tm
=
TriangleMultiplicationOutgoing
(
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c_z
,
c
,
c
,
outgoing
,
)
)
n_res
=
consts
.
c_z
n_res
=
consts
.
c_z
...
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -94,9 +92,10 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
_inference_mode
=
True
,
_inplace_chunk_size
=
4
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
m
ax
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
m
ean
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_mul_out_compare
(
self
):
def
test_tri_mul_out_compare
(
self
):
...
@@ -106,6 +105,40 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -106,6 +105,40 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def
test_tri_mul_in_compare
(
self
):
def
test_tri_mul_in_compare
(
self
):
self
.
_tri_mul_compare
(
incoming
=
True
)
self
.
_tri_mul_compare
(
incoming
=
True
)
def _tri_mul_inference_mode(self, incoming=False):
    """Check that the in-place inference path of the triangular
    multiplicative update reproduces the stock forward pass.

    Args:
        incoming:
            If True, exercise the incoming-edge module (tri_mul_in);
            otherwise the outgoing-edge module (tri_mul_out).
    """
    seq_len = consts.n_res

    # Random pair activations plus a random binary pair mask.
    act_np = np.random.rand(seq_len, seq_len, consts.c_z).astype(np.float32)
    mask_np = np.random.randint(low=0, high=2, size=(seq_len, seq_len))
    mask_np = mask_np.astype(np.float32)

    def to_gpu(arr):
        # Fresh float32 GPU tensor from a numpy array.
        return torch.as_tensor(arr, dtype=torch.float32).cuda()

    of_model = compare_utils.get_global_pretrained_openfold()
    block_core = of_model.evoformer.blocks[0].core
    tri_module = block_core.tri_mul_in if incoming else block_core.tri_mul_out

    # Run the stock (out-of-place) pass first...
    baseline = tri_module(
        to_gpu(act_np),
        mask=to_gpu(mask_np),
        _inference_mode=False,
    ).cpu()

    # ...because the inference-mode pass mutates its input in place.
    inplace_out = tri_module(
        to_gpu(act_np),
        mask=to_gpu(mask_np),
        _inference_mode=True,
        _inplace_chunk_size=2,
    ).cpu()

    self.assertTrue(
        torch.mean(torch.abs(baseline - inplace_out)) < consts.eps
    )
def test_tri_mul_out_inference(self):
    # Outgoing-edge variant of the in-place inference equivalence check.
    self._tri_mul_inference_mode(incoming=False)
def test_tri_mul_in_inference(self):
    # Incoming-edge variant of the in-place inference equivalence check.
    self._tri_mul_inference_mode(incoming=True)
if __name__ == "__main__":
    # Allow this test module to be executed directly as a script.
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment