OpenFold — commit 4bd1b4d5
Authored Apr 28, 2022 by Gustaf Ahdritz
Parent: 54164fe8

    Work on multimer continues

Showing 20 changed files, with 2,327 additions and 689 deletions.
Changed files:

openfold/__init__.py                           +1    -0
openfold/config.py                             +84   -6
openfold/data/data_pipeline.py                 +152  -145
openfold/data/data_transforms.py               +16   -4
openfold/data/data_transforms_multimer.py      +303  -0
openfold/data/feature_pipeline.py              +21   -7
openfold/data/feature_processing_multimer.py   +13   -9
openfold/data/input_pipeline_multimer.py       +135  -0
openfold/data/msa_identifiers.py               +2    -3
openfold/data/msa_pairing.py                   +69   -212
openfold/data/parsers.py                       +20   -3
openfold/data/templates.py                     +82   -60
openfold/data/tools/hhsearch.py                +6    -4
openfold/data/tools/hmmsearch.py               +22   -19
openfold/model/embedders.py                    +583  -13
openfold/model/model.py                        +79   -111
openfold/model/structure_module.py             +215  -92
openfold/np/protein.py                         +1    -1
openfold/utils/all_atom_multimer.py            +493  -0
openfold/utils/argparse_utils.py               +30   -0
openfold/__init__.py

 from . import model
 from . import utils
+from . import data
 from . import np
 from . import resources
openfold/config.py

@@ -75,7 +75,17 @@ def model_config(name, train=False, low_prec=False):
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
     elif "multimer" in name:
-        c.model.update(multimer_model_config_update)
+        c.globals.is_multimer = True
+        for k, v in multimer_model_config_update.items():
+            c.model[k] = v
+        c.data.common.unsupervised_features.extend([
+            "msa_mask",
+            "seq_mask",
+            "asym_id",
+            "entity_id",
+            "sym_id",
+        ])
     else:
         raise ValueError("Invalid model name")

@@ -276,6 +286,7 @@ config = mlc.ConfigDict(
             "c_e": c_e,
             "c_s": c_s,
             "eps": eps,
+            "is_multimer": False,
         },
         "model": {
             "_mask_trans": False,

@@ -335,6 +346,7 @@ config = mlc.ConfigDict(
             "eps": eps,  # 1e-6,
             "enabled": templates_enabled,
             "embed_angles": embed_template_torsion_angles,
+            "use_unit_vector": False,
         },
         "extra_msa": {
             "extra_msa_embedder": {

@@ -496,10 +508,76 @@ config = mlc.ConfigDict(
     }
 )

-multimer_model_config_update = mlc.ConfigDict(
-    {
-        "relative_encoding": {
-            "enabled": True,
-            "max_relative_chain": 2,
-            "max_relative_idx": 32,
-        }
-    }
-)
+multimer_model_config_update = {
+    "input_embedder": {
+        "tf_dim": 21,
+        "msa_dim": 49,
+        "c_z": c_z,
+        "c_m": c_m,
+        "relpos_k": 32,
+        "max_relative_chain": 2,
+        "max_relative_idx": 32,
+        "use_chain_relative": True,
+    },
+    "template": {
+        "distogram": {
+            "min_bin": 3.25,
+            "max_bin": 50.75,
+            "no_bins": 39,
+        },
+        "template_pair_embedder": {
+            "c_z": c_z,
+            "c_out": 64,
+            "c_dgram": 39,
+            "c_aatype": 22,
+        },
+        "template_single_embedder": {
+            "c_in": 34,
+            "c_m": c_m,
+        },
+        "template_pair_stack": {
+            "c_t": c_t,
+            # DISCREPANCY: c_hidden_tri_att here is given in the supplement
+            # as 64. In the code, it's 16.
+            "c_hidden_tri_att": 16,
+            "c_hidden_tri_mul": 64,
+            "no_blocks": 2,
+            "no_heads": 4,
+            "pair_transition_n": 2,
+            "dropout_rate": 0.25,
+            "blocks_per_ckpt": blocks_per_ckpt,
+            "inf": 1e9,
+        },
+        "c_t": c_t,
+        "c_z": c_z,
+        "inf": 1e5,  # 1e9,
+        "eps": eps,  # 1e-6,
+        "enabled": templates_enabled,
+        "embed_angles": embed_template_torsion_angles,
+    },
+    "heads": {
+        "lddt": {
+            "no_bins": 50,
+            "c_in": c_s,
+            "c_hidden": 128,
+        },
+        "distogram": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+        },
+        "tm": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+            "enabled": tm_enabled,
+        },
+        "masked_msa": {
+            "c_m": c_m,
+            "c_out": 22,
+        },
+        "experimentally_resolved": {
+            "c_s": c_s,
+            "c_out": 37,
+        },
+    },
+}
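For orientation, a minimal sketch of how the new branch would be exercised. The preset name "model_1_multimer" is an assumption for illustration; any name containing "multimer" takes this path, and ml_collections converts the assigned dict subtrees to ConfigDicts on assignment.

# Hypothetical usage; the preset name is an assumption, not taken from this diff.
from openfold.config import model_config

c = model_config("model_1_multimer")
assert c.globals.is_multimer
# Each top-level key of multimer_model_config_update replaces the matching
# c.model subtree wholesale:
print(c.model.input_embedder.use_chain_relative)  # True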
openfold/data/data_pipeline.py  (+152 -145)

(diff collapsed in the page view; contents not shown)
openfold/data/data_transforms.py

@@ -428,10 +428,16 @@ def make_hhblits_profile(protein):
 @curry1
-def make_masked_msa(protein, config, replace_fraction):
+def make_masked_msa(protein, config, replace_fraction, seed):
     """Create data for BERT on raw MSA."""
+    device = protein["msa"].device
     # Add a random amino acid uniformly.
-    random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
+    random_aa = torch.tensor(
+        [0.05] * 20 + [0.0, 0.0], dtype=torch.float32, device=device
+    )
     categorical_probs = (
         config.uniform_prob * random_aa

@@ -449,11 +455,17 @@ def make_masked_msa(protein, config, replace_fraction):
     )
     assert mask_prob >= 0.0
     categorical_probs = torch.nn.functional.pad(
-        categorical_probs, pad_shapes, value=mask_prob
+        categorical_probs, pad_shapes, value=mask_prob,
     )
     sh = protein["msa"].shape
-    mask_position = torch.rand(sh) < replace_fraction
+    g = torch.Generator(device=protein["msa"].device)
+    if seed is not None:
+        g.manual_seed(seed)
+    sample = torch.rand(sh, device=device, generator=g)
+    mask_position = sample < replace_fraction
     bert_msa = shaped_categorical(categorical_probs)
     bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
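The change above threads an explicit torch.Generator through the masking transform. A self-contained sketch of the pattern, showing that a fixed seed reproduces the same mask positions (useful when recycling iterations must see consistent corruption):

import torch

def mask_positions(shape, replace_fraction, seed=None, device="cpu"):
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)
    return torch.rand(shape, device=device, generator=g) < replace_fraction

a = mask_positions((4, 8), 0.15, seed=42)
b = mask_positions((4, 8), 0.15, seed=42)
assert torch.equal(a, b)  # identical masks for identical seeds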
openfold/data/data_transforms_multimer.py  (new file)

from typing import Sequence

import torch

from openfold.data.data_transforms import curry1
from openfold.utils.tensor_utils import masked_mean


def gumbel_noise(
    shape: Sequence[int],
    device: torch.device,
    eps=1e-6,
    generator=None,
) -> torch.Tensor:
    """Generate Gumbel noise of the given shape.

    This generates samples from Gumbel(0, 1).

    Args:
        shape: Shape of noise to return.

    Returns:
        Gumbel noise of given shape.
    """
    uniform_noise = torch.rand(
        shape, dtype=torch.float32, device=device, generator=generator
    )
    gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
    return gumbel


def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
    """Samples from a probability distribution given by 'logits'.

    This uses the Gumbel-max trick to implement the sampling in an efficient
    manner.

    Args:
        logits: Logarithm of probabilities to sample from; probabilities can
            be unnormalized.

    Returns:
        Sample from logprobs in one-hot form.
    """
    z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
    return torch.nn.functional.one_hot(
        torch.argmax(logits + z, dim=-1),
        logits.shape[-1],
    )


def gumbel_argsort_sample_idx(
    logits: torch.Tensor, generator=None
) -> torch.Tensor:
    """Samples a permutation from a distribution given by 'logits'.

    This uses the Gumbel trick to implement the sampling in an efficient
    manner. For a distribution over k items this samples k times without
    replacement, so this is effectively sampling a random permutation with
    probabilities over the permutations derived from the logprobs.

    Args:
        logits: Logarithm of probabilities to sample from; probabilities can
            be unnormalized.

    Returns:
        Sampled indices, best first.
    """
    z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
    return torch.argsort(logits + z, dim=-1, descending=True)


@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
    """Create data for BERT on raw MSA."""
    # Add a random amino acid uniformly.
    # (Using torch.tensor here; the committed torch.Tensor(..., device=...)
    # legacy constructor rejects non-CPU devices.)
    random_aa = torch.tensor(
        [0.05] * 20 + [0., 0.], device=batch['msa'].device
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * batch['msa_profile']
        + config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
    )

    # Put all remaining probability on [MASK], which is a new column.
    mask_prob = (
        1. - config.profile_prob - config.same_prob - config.uniform_prob
    )

    categorical_probs = torch.nn.functional.pad(
        categorical_probs, [0, 1], value=mask_prob
    )

    sh = batch['msa'].shape
    mask_position = (
        torch.rand(sh, device=batch['msa'].device) < replace_fraction
    )
    mask_position *= batch['msa_mask'].to(mask_position.dtype)

    logits = torch.log(categorical_probs + eps)

    g = torch.Generator(device=batch["msa"].device)
    if seed is not None:
        g.manual_seed(seed)

    bert_msa = gumbel_max_sample(logits, generator=g)

    bert_msa = torch.where(
        mask_position, torch.argmax(bert_msa, dim=-1), batch['msa']
    )
    bert_msa *= batch['msa_mask'].to(bert_msa.dtype)

    # Mix real and masked MSA.
    if 'bert_mask' in batch:
        batch['bert_mask'] *= mask_position.to(torch.float32)
    else:
        batch['bert_mask'] = mask_position.to(torch.float32)
    batch['true_msa'] = batch['msa']
    batch['msa'] = bert_msa

    return batch


@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
    device = batch["msa_mask"].device

    # Determine how much weight we assign to each agreement. In theory, we
    # could use a full blosum matrix here, but right now let's just
    # down-weight gap agreement because it could be spurious.
    # Never put weight on agreeing on BERT mask.
    weights = torch.tensor(
        [1.] * 21 + [gap_agreement_weight] + [0.],
        device=device,
    )

    msa_mask = batch['msa_mask']
    msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)

    extra_mask = batch['extra_msa_mask']
    extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)

    msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
    extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

    agreement = torch.einsum(
        'mrc, nrc->nm', extra_one_hot_masked, weights * msa_one_hot_masked
    )

    cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
    cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)

    cluster_count = torch.sum(cluster_assignment, dim=-1)
    cluster_count += 1.  # We always include the sequence itself.

    msa_sum = torch.einsum(
        'nm, mrc->nrc', cluster_assignment, extra_one_hot_masked
    )
    msa_sum += msa_one_hot_masked

    cluster_profile = msa_sum / cluster_count[:, None, None]

    extra_deletion_matrix = batch['extra_deletion_matrix']
    deletion_matrix = batch['deletion_matrix']

    del_sum = torch.einsum(
        'nm, mc->nc', cluster_assignment, extra_mask * extra_deletion_matrix
    )
    del_sum += deletion_matrix  # Original sequence.
    cluster_deletion_mean = del_sum / cluster_count[:, None]

    batch['cluster_profile'] = cluster_profile
    batch['cluster_deletion_mean'] = cluster_deletion_mean

    return batch
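The cluster assignment above uses softmax(1e3 * agreement) rather than a hard argmax. A small standalone sketch of why the sharp temperature makes the soft assignment effectively one-hot while remaining a smooth tensor op; the toy agreement matrix is invented for illustration:

import torch

# [num_clusters, num_extra]: column j scores extra sequence j against each
# sampled-MSA cluster center (invented numbers).
agreement = torch.tensor([
    [5.0, 0.5],
    [1.0, 4.0],
])
# Sharp softmax over the cluster axis, as in nearest_neighbor_clusters above.
assignment = torch.softmax(1e3 * agreement, dim=0)
print(assignment.round())  # tensor([[1., 0.], [0., 1.]])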
def create_target_feat(batch):
    """Create the target features."""
    batch["target_feat"] = torch.nn.functional.one_hot(
        batch["aatype"], 21
    ).to(torch.float32)
    return batch


def create_msa_feat(batch):
    """Create and concatenate MSA features."""
    device = batch["msa"].device
    msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
    deletion_matrix = batch['deletion_matrix']
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]

    deletion_mean_value = (
        torch.atan(batch['cluster_deletion_mean'] / 3.) * (2. / pi)
    )[..., None]

    msa_feat = torch.cat(
        [
            msa_1hot,
            has_deletion,
            deletion_value,
            batch['cluster_profile'],
            deletion_mean_value,
        ],
        dim=-1,
    )

    batch["msa_feat"] = msa_feat

    return batch


def build_extra_msa_feat(batch):
    """Expand extra_msa into 1hot and concat with other extra msa features.

    We do this as late as possible as the one_hot extra msa can be very large.

    Args:
        batch: a dictionary with the following keys:
            * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a
                cluster centre. Note - this isn't one-hotted.
            * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions
                at given position.
        num_extra_msa: Number of extra msa to use.

    Returns:
        Concatenated tensor of extra MSA features.
    """
    # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
    extra_msa = batch['extra_msa']
    deletion_matrix = batch['extra_deletion_matrix']
    msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
    extra_msa_mask = batch['extra_msa_mask']
    catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)

    return catted


@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
    """Sample MSA randomly; remaining sequences are stored as `extra_*`.

    Args:
        batch: batch to sample msa from.
        max_seq: number of sequences to sample.

    Returns:
        Protein with sampled msa.
    """
    g = torch.Generator(device=batch["msa"].device)
    if seed is not None:
        g.manual_seed(seed)

    # Sample uniformly among sequences with at least one non-masked position.
    logits = (
        torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.
    ) * inf

    # The cluster_bias_mask can be used to preserve the first row (target
    # sequence) for each chain, for example.
    if 'cluster_bias_mask' not in batch:
        cluster_bias_mask = torch.nn.functional.pad(
            batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
            (1, 0),
            value=1.,
        )
    else:
        cluster_bias_mask = batch['cluster_bias_mask']

    logits += cluster_bias_mask * inf
    index_order = gumbel_argsort_sample_idx(logits, generator=g)
    sel_idx = index_order[:max_seq]
    extra_idx = index_order[max_seq:][:max_extra_msa_seq]

    for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
        if k in batch:
            batch['extra_' + k] = batch[k][extra_idx]
            batch[k] = batch[k][sel_idx]

    return batch


def make_msa_profile(batch):
    """Compute the MSA profile."""
    # Compute the profile for every residue (over all MSA sequences).
    batch["msa_profile"] = masked_mean(
        batch['msa_mask'][..., None],
        torch.nn.functional.one_hot(batch['msa'], 22),
        dim=-3,
    )
    return batch
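This file leans on the Gumbel-max trick: adding Gumbel(0, 1) noise to logits and taking an argmax draws exactly from the corresponding categorical distribution. A self-contained check that the empirical frequencies converge to the normalized probabilities:

import torch

def gumbel_noise(shape, eps=1e-6):
    u = torch.rand(shape, dtype=torch.float32)
    return -torch.log(-torch.log(u + eps) + eps)

probs = torch.tensor([0.1, 0.2, 0.7])
logits = torch.log(probs)  # any constant offset to logits changes nothing
counts = torch.zeros(3)
for _ in range(20000):
    counts[torch.argmax(logits + gumbel_noise(logits.shape))] += 1
print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.7])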
openfold/data/feature_pipeline.py

@@ -20,7 +20,7 @@ import ml_collections
 import numpy as np
 import torch

-from openfold.data import input_pipeline
+from openfold.data import input_pipeline, input_pipeline_multimer

 FeatureDict = Mapping[str, np.ndarray]

@@ -73,8 +73,10 @@ def np_example_to_features(
     np_example: FeatureDict,
     config: ml_collections.ConfigDict,
     mode: str,
+    is_multimer: bool = False,
 ):
     np_example = dict(np_example)
     num_res = int(np_example["seq_length"][0])
     cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

@@ -87,11 +89,18 @@ def np_example_to_features(
         np_example=np_example, features=feature_names
     )

     with torch.no_grad():
-        features = input_pipeline.process_tensors_from_config(
-            tensor_dict,
-            cfg.common,
-            cfg[mode],
-        )
+        if (not is_multimer):
+            features = input_pipeline.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
+        else:
+            features = input_pipeline_multimer.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )

     return {k: v for k, v in features.items()}

@@ -107,9 +116,14 @@ class FeaturePipeline:
         self,
         raw_features: FeatureDict,
         mode: str = "train",
+        is_multimer: bool = False,
     ) -> FeatureDict:
+        if (is_multimer and mode != "predict"):
+            raise ValueError("Multimer mode is not currently trainable")
+
         return np_example_to_features(
             np_example=raw_features,
             config=self.config,
             mode=mode,
+            is_multimer=is_multimer,
         )
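A hypothetical call into the extended pipeline. The processing method's name is not visible in this hunk and is assumed here, as are the config object and the shape of raw_features:

from openfold.data.feature_pipeline import FeaturePipeline

pipeline = FeaturePipeline(config.data)  # config as produced by model_config(...)
# Multimer features can currently only be processed for inference;
# mode="train" with is_multimer=True raises ValueError.
feats = pipeline.process_features(  # method name assumed
    raw_features, mode="predict", is_multimer=True
)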
openfold/data/multimer_feature_processing.py → openfold/data/feature_processing_multimer.py

@@ -15,7 +15,7 @@
 """Feature processing logic for multimer data pipeline."""

-from typing import Iterable, MutableMapping, List
+from typing import Iterable, MutableMapping, List, Mapping

 from openfold.data import msa_pairing
 from openfold.np import residue_constants

@@ -49,13 +49,11 @@ def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
 def pair_and_merge(
     all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
-    is_prokaryote: bool) -> Mapping[str, np.ndarray]:
+) -> Mapping[str, np.ndarray]:
     """Runs processing on features to augment, pair and merge.

     Args:
         all_chain_features: A MutableMap of dictionaries of features for each chain.
-        is_prokaryote: Whether the target complex is from a prokaryotic or
-            eukaryotic organism.

     Returns:
         A dictionary of features.

@@ -69,7 +67,8 @@ def pair_and_merge(
     if pair_msa_sequences:
         np_chains_list = msa_pairing.create_paired_features(
-            chains=np_chains_list, prokaryotic=is_prokaryote)
+            chains=np_chains_list)
         np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
     np_chains_list = crop_chains(
         np_chains_list,

@@ -175,6 +174,7 @@ def process_final(
     np_example = _make_seq_mask(np_example)
     np_example = _make_msa_mask(np_example)
     np_example = _filter_features(np_example)
     return np_example

@@ -210,19 +210,23 @@ def _filter_features(
 def process_unmerged_features(
-    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]):
+    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
+):
     """Postprocessing stage for per-chain features before merging."""
     num_chains = len(all_chain_features)
     for chain_features in all_chain_features.values():
         # Convert deletion matrices to float.
         chain_features['deletion_matrix'] = np.asarray(
             chain_features.pop('deletion_matrix_int'), dtype=np.float32)
         if 'deletion_matrix_int_all_seq' in chain_features:
             chain_features['deletion_matrix_all_seq'] = np.asarray(
                 chain_features.pop('deletion_matrix_int_all_seq'),
                 dtype=np.float32)
         chain_features['deletion_mean'] = np.mean(
             chain_features['deletion_matrix'], axis=0)

         # Add all_atom_mask and dummy all_atom_positions based on aatype.
         all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
openfold/data/input_pipeline_multimer.py  (new file)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch

from openfold.data import (
    data_transforms,
    data_transforms_multimer,
)


def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
    ]

    if (common_cfg.use_templates):
        transforms.extend([
            data_transforms.make_pseudo_beta("template_"),
        ])

    return transforms


def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    msa_seed = None
    if (not common_cfg.resample_msa_in_recycling):
        msa_seed = ensemble_seed

    transforms.append(
        data_transforms_multimer.sample_msa(
            max_msa_clusters,
            max_extra_msa,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms_multimer.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                seed=(msa_seed + 1) if msa_seed else None,
            )
        )

    transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
    transforms.append(data_transforms_multimer.create_msa_feat)

    return transforms


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    no_templates = True
    if ("template_aatype" in tensors):
        no_templates = tensors["template_aatype"].shape[0] == 0

    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    if ("no_recycling_iters" in tensors):
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors


@data_transforms.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
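The pipeline above is built from curried transforms folded together by compose. A minimal standalone sketch of the pattern, with a toy curry1 that mirrors the behavior relied on here (not the openfold implementation itself):

def curry1(f):
    """Bind all arguments but the first; the data argument arrives last."""
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc

@curry1
def add(x, n):
    return x + n

transforms = [add(1), add(2)]  # arguments bound up front, like sample_msa(...)
value = 0
for fn in transforms:          # what compose(fns) does internally
    value = fn(value)
print(value)  # 3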
openfold/data/msa_identifiers.py

@@ -48,7 +48,6 @@ _UNIPROT_PATTERN = re.compile(
 @dataclasses.dataclass(frozen=True)
 class Identifiers:
-    uniprot_accession_id: str = ''
     species_id: str = ''

@@ -69,8 +68,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
     matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
     if matches:
         return Identifiers(
-            uniprot_accession_id=matches.group('AccessionIdentifier'),
-            species_id=matches.group('SpeciesIdentifier')
+            species_id=matches.group('SpeciesIdentifier')
         )
     return Identifiers()
openfold/data/msa_pairing.py

@@ -17,7 +17,7 @@
 import collections
 import functools
 import string
-from typing import Any, Dict, Iterable, List, Sequence
+from typing import Any, Dict, Iterable, List, Sequence, Mapping

 import numpy as np
 import pandas as pd

@@ -27,12 +27,6 @@ from openfold.np import residue_constants
 # TODO: This stuff should probably also be in a config
-ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
-ALPHANUM_ACCESSION_ID_MAP = {
-    chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
-}  # A-Z,0-9
-NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)}  # 0-9
 MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
 SEQUENCE_GAP_CUTOFF = 0.5
 SEQUENCE_SIMILARITY_CUTOFF = 0.9

@@ -61,14 +55,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')
 def create_paired_features(
     chains: Iterable[Mapping[str, np.ndarray]],
-    prokaryotic: bool,
 ) -> List[Mapping[str, np.ndarray]]:
     """Returns the original chains with paired NUM_SEQ features.

     Args:
         chains: A list of feature dictionaries for each chain.
-        prokaryotic: Whether the target complex is from a prokaryotic organism.
-            Used to determine the distance metric for pairing.

     Returns:
         A list of feature dictionaries with sequence features including only

@@ -81,8 +72,7 @@ def create_paired_features(
         return chains
     else:
         updated_chains = []
         paired_chains_to_paired_row_indices = pair_sequences(
-            chains, prokaryotic)
+            chains)
         paired_rows = reorder_paired_rows(
             paired_chains_to_paired_row_indices)

@@ -117,8 +107,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
         num_res = feature.shape[1]
         padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                          feature.dtype)
-    elif feature_name in ('msa_uniprot_accession_identifiers_all_seq',
-                          'msa_species_identifiers_all_seq'):
+    elif feature_name == 'msa_species_identifiers_all_seq':
         padding = [b'']
     else:
         return feature

@@ -136,11 +125,9 @@ def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
     msa_df = pd.DataFrame({
         'msa_species_identifiers':
             chain_features['msa_species_identifiers_all_seq'],
-        'msa_uniprot_accession_identifiers':
-            chain_features['msa_uniprot_accession_identifiers_all_seq'],
         'msa_row':
-            np.arange(len(
-                chain_features['msa_uniprot_accession_identifiers_all_seq'])),
+            np.arange(len(
+                chain_features['msa_species_identifiers_all_seq'])),
         'msa_similarity': per_seq_similarity,
         'gap': per_seq_gap
     })

@@ -155,139 +142,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
     return species_lookup

-@functools.lru_cache(maxsize=65536)
-def encode_accession(accession_id: str) -> int:
-    """Map accession codes to the serial order in which they were assigned."""
-    alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
-    alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
-    num = NUM_ACCESSION_ID_MAP  # 0-9
-
-    coding = 0
-
-    # This is based on the uniprot accession id format
-    # https://www.uniprot.org/help/accession_numbers
-    if accession_id[0] in {'O', 'P', 'Q'}:
-        bases = (alpha, num, alphanum, alphanum, alphanum, num)
-    elif len(accession_id) == 6:
-        bases = (alpha, num, alpha, alphanum, alphanum, num)
-    elif len(accession_id) == 10:
-        bases = (alpha, num, alpha, alphanum, alphanum, num,
-                 alpha, alphanum, alphanum, num)
-
-    product = 1
-    for place, base in zip(reversed(accession_id), reversed(bases)):
-        coding += base[place] * product
-        product *= len(base)
-
-    return coding
-
-
-def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
-    return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))
-
-
-def _find_all_accession_matches(
-    accession_id_lists: List[List[bytes]],
-    diff_cutoff: int = 20,
-) -> List[List[Any]]:
-    """Finds accession id matches across the chains based on their difference."""
-    all_accession_tuples = []
-    current_tuple = []
-    tokens_used_in_answer = set()
-
-    def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
-        return all(
-            (_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple)
-        )
-
-    def _all_tokens_not_used_before() -> bool:
-        return all(
-            (s not in tokens_used_in_answer for s in current_tuple)
-        )
-
-    def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
-        if level == len(accession_id_lists) - 1:
-            if _all_tokens_not_used_before():
-                all_accession_tuples.append(list(current_tuple))
-                for s in current_tuple:
-                    tokens_used_in_answer.add(s)
-            return
-
-        if level == -1:
-            new_list = accession_id_lists[level + 1]
-        else:
-            new_list = [
-                (_calc_id_diff(accession_id, s), s)
-                for s in accession_id_lists[level + 1]
-            ]
-            new_list = sorted(new_list)
-            new_list = [s for d, s in new_list]
-
-        for s in new_list:
-            if (
-                _matches_all_in_current_tuple(s, diff_cutoff)
-                and s not in tokens_used_in_answer
-            ):
-                current_tuple.append(s)
-                dfs(level + 1, s)
-                current_tuple.pop()
-
-    dfs(-1, '')
-    return all_accession_tuples
-
-
-def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
-    matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
-    return matched_df.iloc[0]
-
-
-def _match_rows_by_genetic_distance(
-    this_species_msa_dfs: List[pd.DataFrame],
-    cutoff: int = 20,
-) -> List[List[int]]:
-    """Finds MSA sequence pairings across chains within a genetic distance cutoff.
-
-    The genetic distance between two sequences is approximated by taking the
-    difference in their UniProt accession ids.
-
-    Args:
-        this_species_msa_dfs: a list of dataframes containing MSA features for
-            sequences for a specific species. If species is missing for a chain,
-            the dataframe is set to None.
-        cutoff: the genetic distance cutoff.
-
-    Returns:
-        A list of lists, each containing M indices corresponding to paired MSA
-        rows, where M is the number of chains.
-    """
-    num_examples = len(this_species_msa_dfs)  # N
-
-    accession_id_lists = []  # M
-    match_index_to_chain_index = {}
-    for chain_index, species_df in enumerate(this_species_msa_dfs):
-        if species_df is not None:
-            accession_id_lists.append(
-                list(species_df.msa_uniprot_accession_identifiers.values))
-            # Keep track of which of the this_species_msa_dfs are not None.
-            match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index
-
-    all_accession_id_matches = _find_all_accession_matches(
-        accession_id_lists, cutoff)  # [k, M]
-
-    all_paired_msa_rows = []  # [k, N]
-    for accession_id_match in all_accession_id_matches:
-        paired_msa_rows = []
-        for match_index, accession_id in enumerate(accession_id_match):
-            # Map back to chain index.
-            chain_index = match_index_to_chain_index[match_index]
-            seq_series = _accession_row(
-                this_species_msa_dfs[chain_index], accession_id)
-
-            if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF
-                    or seq_series.gap > SEQUENCE_GAP_CUTOFF):
-                continue
-            else:
-                paired_msa_rows.append(seq_series.msa_row)
-        # If a sequence is skipped based on sequence similarity to the
-        # respective target sequence or a gap cutoff, the lengths of
-        # accession_id_match and paired_msa_rows will be different. Skip
-        # this match.
-        if len(paired_msa_rows) == len(accession_id_match):
-            paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
-            matched_chain_indices = list(match_index_to_chain_index.values())
-            paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
-            all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
-    return all_paired_msa_rows
-
 def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
                                        ) -> List[List[int]]:
     """Finds MSA sequence pairings across chains based on sequence similarity.

@@ -324,8 +178,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
     return all_paired_msa_rows

-def pair_sequences(examples: List[Mapping[str, np.ndarray]],
-                   prokaryotic: bool) -> Dict[int, np.ndarray]:
+def pair_sequences(
+    examples: List[Mapping[str, np.ndarray]],
+) -> Dict[int, np.ndarray]:
     """Returns indices for paired MSA sequences across chains."""

     num_examples = len(examples)

@@ -367,22 +222,6 @@ def pair_sequences(
                 isinstance(species_df, pd.DataFrame)]) > 600):
             continue

-        # In prokaryotes (and some eukaryotes), interacting genes are often
-        # co-located on the chromosome into operons. Because of that we can
-        # assume that if two proteins' intergenic distance is less than a
-        # threshold, the two proteins will form an interacting pair.
-        # In most eukaryotes, a single protein's MSA can contain many paralogs.
-        # Two genes may interact even if they are not close by genomic distance.
-        # In case of eukaryotes, some methods pair MSA sequences using a
-        # sequence similarity method.
-        # See Jinbo Xu's work:
-        # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
-        if prokaryotic:
-            paired_msa_rows = _match_rows_by_genetic_distance(
-                this_species_msa_dfs)
-
-            if not paired_msa_rows:
-                continue
-        else:
-            paired_msa_rows = _match_rows_by_sequence_similarity(
-                this_species_msa_dfs)
+        paired_msa_rows = _match_rows_by_sequence_similarity(
+            this_species_msa_dfs)
         all_paired_msa_rows.extend(paired_msa_rows)
         all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)

@@ -431,13 +270,19 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
 def _correct_post_merged_feats(
     np_example: Mapping[str, np.ndarray],
     np_chains_list: Sequence[Mapping[str, np.ndarray]],
     pair_msa_sequences: bool
 ) -> Mapping[str, np.ndarray]:
     """Adds features that need to be computed/recomputed post merging."""

-    np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
-                                          dtype=np.int32)
+    num_res = np_example['aatype'].shape[0]
+    np_example['seq_length'] = np.asarray(
+        [num_res] * num_res,
+        dtype=np.int32)
     np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
                                               dtype=np.int32)

     if not pair_msa_sequences:
         # Generate a bias that is 1 for the first row of every block in the

@@ -449,29 +294,41 @@ def _correct_post_merged_feats(
             mask = np.zeros(chain['msa'].shape[0])
             mask[0] = 1
             cluster_bias_masks.append(mask)
         np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

         # Initialize Bert mask with masked out off diagonals.
         msa_masks = [
             np.ones(x['msa'].shape, dtype=np.float32)
             for x in np_chains_list
         ]

         np_example['bert_mask'] = block_diag(
             *msa_masks, pad_value=0)
     else:
         np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
         np_example['cluster_bias_mask'][0] = 1

         # Initialize Bert mask with masked out off diagonals.
         msa_masks = [
             np.ones(x['msa'].shape, dtype=np.float32) for
             x in np_chains_list
         ]
         msa_masks_all_seq = [
             np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
             x in np_chains_list
         ]

         msa_mask_block_diag = block_diag(
             *msa_masks, pad_value=0)
         msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
         np_example['bert_mask'] = np.concatenate(
             [msa_mask_all_seq, msa_mask_block_diag], axis=0)

     return np_example
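This commit deletes the prokaryotic pairing path, including encode_accession. For the record, a worked sketch of the mixed-radix encoding that function performed, with toy maps mirroring the removed module constants:

import string

ALPHA = {c: i for i, c in enumerate(string.ascii_uppercase)}                      # base 26
ALPHANUM = {c: i for i, c in enumerate(string.ascii_uppercase + string.digits)}   # base 36
NUM = {str(i): i for i in range(10)}                                              # base 10

def encode(accession_id: str) -> int:
    # Six-character UniProt accessions starting with O/P/Q use the pattern
    # letter, digit, alphanumeric x3, digit; each position is a "digit" in
    # its own base, so the id maps to an integer positionally.
    bases = (ALPHA, NUM, ALPHANUM, ALPHANUM, ALPHANUM, NUM)
    coding, product = 0, 1
    for place, base in zip(reversed(accession_id), reversed(bases)):
        coding += base[place] * product
        product *= len(base)
    return coding

# Consecutively assigned accessions differ by 1, which is how the removed
# code approximated genetic distance:
print(encode("P12346") - encode("P12345"))  # 1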
openfold/data/parsers.py

@@ -16,6 +16,7 @@
 """Functions for parsing various file formats."""
 import collections
 import dataclasses
+import itertools
 import re
 import string
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

@@ -29,8 +30,7 @@ class Msa:
     """Class representing a parsed MSA file"""

     sequences: Sequence[str]
     deletion_matrix: DeletionMatrix
-    descriptions: Sequence[str]
+    descriptions: Optional[Sequence[str]]

     def __post_init__(self):
         if (not (

@@ -642,3 +642,20 @@ def parse_hmmsearch_a3m(
         hits.append(hit)

     return hits

+def parse_hmmsearch_sto(
+    output_string: str, input_sequence: str
+) -> Sequence[TemplateHit]:
+    """Gets parsed template hits from the raw string output by the tool."""
+    a3m_string = convert_stockholm_to_a3m(
+        output_string, remove_first_row_gaps=False
+    )
+    template_hits = parse_hmmsearch_a3m(
+        query_sequence=input_sequence,
+        a3m_string=a3m_string,
+        skip_first=False
+    )
+    return template_hits
openfold/data/templates.py

@@ -220,13 +220,6 @@ def _assess_hhsearch_hit(
     template_sequence = hit.hit_sequence.replace("-", "")
     length_ratio = float(len(template_sequence)) / len(query_sequence)

-    # Check whether the template is a large subsequence or duplicate of original
-    # query. This can happen due to duplicate entries in the PDB database.
-    duplicate = (
-        template_sequence in query_sequence
-        and length_ratio > max_subsequence_ratio
-    )
-
     if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
         date = release_dates[hit_pdb_code.upper()]
         raise DateError(

@@ -240,6 +233,13 @@ def _assess_hhsearch_hit(
             f"Align ratio: {align_ratio}."
         )

+    # Check whether the template is a large subsequence or duplicate of original
+    # query. This can happen due to duplicate entries in the PDB database.
+    duplicate = (
+        template_sequence in query_sequence
+        and length_ratio > max_subsequence_ratio
+    )
+
     if duplicate:
         raise DuplicateError(
             "Template is an exact subsequence of query with large "

@@ -770,7 +770,7 @@ def _prefilter_hit(
     except PrefilterError as e:
         hit_name = f"{hit_pdb_code}_{hit_chain_id}"
         msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
-        logging.info("%s: %s", query_pdb_code, msg)
+        logging.info(msg)
         if strict_error_check and isinstance(e, (DateError, DuplicateError)):
             # In strict mode we treat some prefilter cases as errors.
             return PrefilterResult(valid=False, error=msg, warning=None)

@@ -826,6 +826,7 @@ def _process_single_hit(
         query_sequence,
         template_sequence,
     )
+    # Fail if we can't find the mmCIF file.
     cif_string = _read_file(cif_path)

@@ -968,7 +969,7 @@ class TemplateHitFeaturizer(abc.ABC):
             raise ValueError(
                 "max_template_date must be set and have format YYYY-MM-DD."
             )
-        self.max_hits = max_hits
+        self._max_hits = max_hits
         self._kalign_binary_path = kalign_binary_path
         self._strict_error_check = strict_error_check

@@ -997,33 +998,23 @@ class TemplateHitFeaturizer(abc.ABC):
         query_sequence: str,
         hits: Sequence[parsers.TemplateHit]
     ) -> TemplateSearchResult:
+        """ Computes the templates for a given query sequence """


 class HhsearchHitFeaturizer(TemplateHitFeaturizer):
     def get_templates(
         self,
         query_sequence: str,
-        query_release_date: Optional[datetime.datetime],
         hits: Sequence[parsers.TemplateHit],
     ) -> TemplateSearchResult:
         """Computes the templates for given query sequence (more details above)."""
-        logging.info("Searching for template for: %s", query_pdb_code)
+        logging.info("Searching for template for: %s", query_sequence)

         template_features = {}
         for template_feature_name in TEMPLATE_FEATURES:
             template_features[template_feature_name] = []

-        # Always use a max_template_date. Set to query_release_date minus 60 days
-        # if that's earlier.
-        template_cutoff_date = self._max_template_date
-        if query_release_date:
-            delta = datetime.timedelta(days=60)
-            if query_release_date - delta < template_cutoff_date:
-                template_cutoff_date = query_release_date - delta
-            assert template_cutoff_date < query_release_date
-        assert template_cutoff_date <= self._max_template_date
-
-        num_hits = 0
+        already_seen = set()
+
         errors = []
         warnings = []

@@ -1032,7 +1023,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
             prefilter_result = _prefilter_hit(
                 query_sequence=query_sequence,
                 hit=hit,
-                max_template_date=template_cutoff_date,
+                max_template_date=self._max_template_date,
                 release_dates=self._release_dates,
                 obsolete_pdbs=self._obsolete_pdbs,
                 strict_error_check=self._strict_error_check,

@@ -1057,17 +1048,16 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
         for i in idx:
             # We got all the templates we wanted, stop processing hits.
-            if num_hits >= self.max_hits:
+            if len(already_seen) >= self.max_hits:
                 break

             hit = filtered[i]

             result = _process_single_hit(
                 query_sequence=query_sequence,
-                query_pdb_code=query_pdb_code,
                 hit=hit,
                 mmcif_dir=self._mmcif_dir,
-                max_template_date=template_cutoff_date,
+                max_template_date=self._max_template_date,
                 release_dates=self._release_dates,
                 obsolete_pdbs=self._obsolete_pdbs,
                 strict_error_check=self._strict_error_check,

@@ -1091,8 +1081,10 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
                     result.warning,
                 )
             else:
-                # Increment the hit counter, since we got features out of this hit.
-                num_hits += 1
+                already_seen_key = result.features["template_sequence"]
+                if (already_seen_key in already_seen):
+                    continue
+                already_seen.add(already_seen_key)
                 for k in template_features:
                     template_features[k].append(result.features[k])

@@ -1118,6 +1110,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
         query_sequence: str,
         hits: Sequence[parsers.TemplateHit]
     ) -> TemplateSearchResult:
+        logging.info("Searching for template for: %s", query_sequence)
+
         template_features = {}
         for template_feature_name in TEMPLATE_FEATURES:
             template_features[template_feature_name] = []

@@ -1126,15 +1120,43 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
         errors = []
         warnings = []

-        if not hits or hits[0].sum_probs is None:
-            sorted_hits = hits
-        else:
-            sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
-
-        for hit in sorted_hits:
+        # DISCREPANCY: This filtering scheme that saves time
+        filtered = []
+        for hit in hits:
+            prefilter_result = _prefilter_hit(
+                query_sequence=query_sequence,
+                hit=hit,
+                max_template_date=self._max_template_date,
+                release_dates=self._release_dates,
+                obsolete_pdbs=self._obsolete_pdbs,
+                strict_error_check=self._strict_error_check,
+            )
+            if prefilter_result.error:
+                errors.append(prefilter_result.error)
+            if prefilter_result.warning:
+                warnings.append(prefilter_result.warning)
+            if prefilter_result.valid:
+                filtered.append(hit)
+
+        filtered = list(sorted(
+            filtered,
+            key=lambda x: x.sum_probs if x.sum_probs else 0.,
+            reverse=True,
+        ))
+
+        idx = list(range(len(filtered)))
+        if (self._shuffle_top_k_prefiltered):
+            stk = self._shuffle_top_k_prefiltered
+            idx[:stk] = np.random.permutation(idx[:stk])
+
+        for i in idx:
             if (len(already_seen) >= self._max_hits):
                 break
+            hit = filtered[i]

             result = _process_single_hit(
                 query_sequence=query_sequence,
                 hit=hit,
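The featurizers now cap work by unique template sequence rather than by a raw hit counter. A toy sketch of the already_seen pattern adopted above, with invented sequences:

hits = ["SEQA", "SEQA", "SEQB", "SEQC", "SEQB"]
max_hits = 2
already_seen = set()
kept = []
for h in hits:
    if len(already_seen) >= max_hits:
        break          # enough unique templates collected
    if h in already_seen:
        continue       # duplicate PDB entries don't consume the budget
    already_seen.add(h)
    kept.append(h)
print(kept)  # ['SEQA', 'SEQB']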
openfold/data/tools/hhsearch.py

@@ -18,7 +18,7 @@ import glob
 import logging
 import os
 import subprocess
-from typing import Sequence
+from typing import Sequence, Optional

 from openfold.data import parsers
 from openfold.data.tools import utils

@@ -71,11 +71,12 @@ class HHSearch:
     def input_format(self) -> str:
         return 'a3m'

-    def query(self, a3m: str) -> str:
+    def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
         """Queries the database using HHsearch using a given a3m."""
         with utils.tmpdir_manager() as query_tmp_dir:
             input_path = os.path.join(query_tmp_dir, "query.a3m")
-            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
+            output_dir = query_tmp_dir if output_dir is None else output_dir
+            hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
             with open(input_path, "w") as f:
                 f.write(a3m)

@@ -114,7 +115,8 @@ class HHSearch:
             hhr = f.read()
         return hhr

+    @staticmethod
     def get_template_hits(
-        self,
         output_string: str,
         input_sequence: str
     ) -> Sequence[parsers.TemplateHit]:
openfold/data/tools/hmmsearch.py
View file @
4bd1b4d5
...
@@ -32,7 +32,8 @@ class Hmmsearch(object):
         binary_path: str,
         hmmbuild_binary_path: str,
         database_path: str,
         flags: Optional[Sequence[str]] = None):
         """Initializes the Python hmmsearch wrapper.

         Args:
...
@@ -71,17 +72,23 @@ class Hmmsearch(object):
     def input_format(self) -> str:
         return 'sto'

-    def query(self, msa_sto: str) -> str:
+    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
         """Queries the database using hmmsearch using a given stockholm msa."""
-        hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
-                                                          model_construction='hand')
-        return self.query_with_hmm(hmm)
+        hmm = self.hmmbuild_runner.build_profile_from_sto(
+            msa_sto,
+            model_construction='hand'
+        )
+        return self.query_with_hmm(hmm, output_dir)

-    def query_with_hmm(self, hmm: str) -> str:
+    def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
         """Queries the database using hmmsearch using a given hmm."""
         with utils.tmpdir_manager() as query_tmp_dir:
             hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
-            out_path = os.path.join(query_tmp_dir, 'output.sto')
+            output_dir = query_tmp_dir if output_dir is None else output_dir
+            out_path = os.path.join(output_dir, 'hmm_output.sto')
             with open(hmm_input_path, 'w') as f:
                 f.write(hmm)
...
@@ -117,18 +124,14 @@ class Hmmsearch(object):
         return out_msa

-    def get_template_hits(self,
+    @staticmethod
+    def get_template_hits(
         output_string: str,
         input_sequence: str
     ) -> Sequence[parsers.TemplateHit]:
         """Gets parsed template hits from the raw string output by the tool."""
-        a3m_string = parsers.convert_stockholm_to_a3m(
-            output_string,
-            remove_first_row_gaps=False
-        )
-        template_hits = parsers.parse_hmmsearch_a3m(
-            query_sequence=input_sequence,
-            a3m_string=a3m_string,
-            skip_first=False
-        )
+        template_hits = parsers.parse_hmmsearch_sto(
+            output_string,
+            input_sequence,
+        )
         return template_hits
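With get_template_hits now a @staticmethod, template hits can be re-parsed from a previously saved hmmsearch output without constructing a runner at all. A sketch, assuming a hypothetical saved hmm_output.sto and a toy query sequence:

    from openfold.data.tools.hmmsearch import Hmmsearch

    query_sequence = "MKTAYIAKQR"          # toy query sequence
    with open("hmm_output.sto") as f:      # hypothetical saved hmmsearch output
        sto = f.read()

    hits = Hmmsearch.get_template_hits(
        output_string=sto,
        input_sequence=query_sequence,
    )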
openfold/model/embedders.py
...
@@ -13,12 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import partial
 import torch
 import torch.nn as nn
 from typing import Tuple

+from openfold.utils import all_atom_multimer
+from openfold.utils.feats import (
+    pseudo_beta_fn,
+    dgram_from_positions,
+    build_template_angle_feat,
+    build_template_pair_feat,
+)
 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import one_hot
+from openfold.model.template import (
+    TemplatePairStack,
+    TemplatePointwiseAttention,
+)
+from openfold.utils import geometry
+from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap


 class InputEmbedder(nn.Module):
...
@@ -85,20 +99,16 @@ class InputEmbedder(nn.Module):
         oh = one_hot(d, boundaries).type(ri.dtype)
         return self.linear_relpos(oh)

-    def forward(self,
-        tf: torch.Tensor,
-        ri: torch.Tensor,
-        msa: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-            tf:
-                "target_feat" features of shape [*, N_res, tf_dim]
-            ri:
-                "residue_index" features of shape [*, N_res]
-            msa:
-                "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
+            batch: Dict containing
+                "target_feat":
+                    Features of shape [*, N_res, tf_dim]
+                "residue_index":
+                    Features of shape [*, N_res]
+                "msa_feat":
+                    Features of shape [*, N_clust, N_res, msa_dim]
         Returns:
             msa_emb:
                 [*, N_clust, N_res, C_m] MSA embedding
...
@@ -106,6 +116,10 @@ class InputEmbedder(nn.Module):
                 [*, N_res, N_res, C_z] pair embedding
         """
+        tf = batch["target_feat"]
+        ri = batch["residue_index"]
+        msa = batch["msa_feat"]
+
         # [*, N_res, c_z]
         tf_emb_i = self.linear_tf_z_i(tf)
         tf_emb_j = self.linear_tf_z_j(tf)
...
@@ -126,6 +140,154 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb


+class InputEmbedderMultimer(nn.Module):
+    """
+    Embeds a subset of the input features.
+
+    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
+    """
+    def __init__(
+        self,
+        tf_dim: int,
+        msa_dim: int,
+        c_z: int,
+        c_m: int,
+        max_relative_idx: int,
+        use_chain_relative: bool,
+        max_relative_chain: int,
+        **kwargs,
+    ):
+        """
+        Args:
+            tf_dim:
+                Final dimension of the target features
+            msa_dim:
+                Final dimension of the MSA features
+            c_z:
+                Pair embedding dimension
+            c_m:
+                MSA embedding dimension
+            relpos_k:
+                Window size used in relative positional encoding
+        """
+        super(InputEmbedderMultimer, self).__init__()
+
+        self.tf_dim = tf_dim
+        self.msa_dim = msa_dim
+
+        self.c_z = c_z
+        self.c_m = c_m
+
+        self.linear_tf_z_i = Linear(tf_dim, c_z)
+        self.linear_tf_z_j = Linear(tf_dim, c_z)
+        self.linear_tf_m = Linear(tf_dim, c_m)
+        self.linear_msa_m = Linear(msa_dim, c_m)
+
+        # RPE stuff
+        self.max_relative_idx = max_relative_idx
+        self.use_chain_relative = use_chain_relative
+        self.max_relative_chain = max_relative_chain
+        if(self.use_chain_relative):
+            self.no_bins = (
+                2 * max_relative_idx + 2 +
+                1 +
+                2 * max_relative_chain + 2
+            )
+        else:
+            self.no_bins = 2 * max_relative_idx + 1
+
+        self.linear_relpos = Linear(self.no_bins, c_z)
+
+    def relpos(self, batch):
+        pos = batch["residue_index"]
+        asym_id = batch["asym_id"]
+        asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
+        offset = pos[..., None] - pos[..., None, :]
+
+        clipped_offset = torch.clamp(
+            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
+        )
+
+        rel_feats = []
+        if(self.use_chain_relative):
+            final_offset = torch.where(
+                asym_id_same,
+                clipped_offset,
+                (2 * self.max_relative_idx + 1) *
+                torch.ones_like(clipped_offset)
+            )
+
+            rel_pos = torch.nn.functional.one_hot(
+                final_offset,
+                2 * self.max_relative_idx + 2,
+            )
+            rel_feats.append(rel_pos)
+
+            entity_id = batch["entity_id"]
+            entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
+            rel_feats.append(entity_id_same[..., None])
+
+            sym_id = batch["sym_id"]
+            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
+
+            max_rel_chain = self.max_relative_chain
+            clipped_rel_chain = torch.clamp(
+                rel_sym_id + max_rel_chain,
+                0,
+                2 * max_rel_chain,
+            )
+
+            final_rel_chain = torch.where(
+                entity_id_same,
+                clipped_rel_chain,
+                (2 * max_rel_chain + 1) *
+                torch.ones_like(clipped_rel_chain)
+            )
+
+            rel_chain = torch.nn.functional.one_hot(
+                final_rel_chain,
+                2 * max_rel_chain + 2,
+            )
+            rel_feats.append(rel_chain)
+        else:
+            rel_pos = torch.nn.functional.one_hot(
+                clipped_offset,
+                2 * self.max_relative_idx + 1,
+            )
+            rel_feats.append(rel_pos)
+
+        rel_feat = torch.cat(rel_feats, dim=-1).to(
+            self.linear_relpos.weight.dtype
+        )
+
+        return self.linear_relpos(rel_feat)
+
+    def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
+        tf = batch["target_feat"]
+        msa = batch["msa_feat"]
+
+        # [*, N_res, c_z]
+        tf_emb_i = self.linear_tf_z_i(tf)
+        tf_emb_j = self.linear_tf_z_j(tf)
+
+        # [*, N_res, N_res, c_z]
+        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
+        pair_emb = pair_emb + self.relpos(batch)
+
+        # [*, N_clust, N_res, c_m]
+        n_clust = msa.shape[-3]
+        tf_m = (
+            self.linear_tf_m(tf)
+            .unsqueeze(-3)
+            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+        )
+        msa_emb = self.linear_msa_m(msa) + tf_m
+
+        return msa_emb, pair_emb
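The chain-relative encoding above is the main difference from the monomer relpos: same-chain pairs get clipped residue offsets, while every cross-chain pair falls into a single overflow bin. A pure-PyTorch sketch of just that step, with a toy max_relative_idx of 2 instead of the configured value:

    import torch

    max_relative_idx = 2
    pos = torch.tensor([0, 1, 2, 0, 1])        # residue_index, restarting per chain
    asym_id = torch.tensor([0, 0, 0, 1, 1])    # chain identity per residue

    asym_id_same = asym_id[..., None] == asym_id[..., None, :]
    offset = pos[..., None] - pos[..., None, :]
    clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)

    # Inter-chain pairs all share one extra "different chain" bin.
    final_offset = torch.where(
        asym_id_same,
        clipped,
        (2 * max_relative_idx + 1) * torch.ones_like(clipped),
    )
    rel_pos = torch.nn.functional.one_hot(final_offset, 2 * max_relative_idx + 2)
    print(rel_pos.shape)  # torch.Size([5, 5, 6])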

 class RecyclingEmbedder(nn.Module):
     """
     Embeds the output of an iteration of the model for recycling.
...
@@ -312,6 +474,102 @@ class TemplatePairEmbedder(nn.Module):
         return x


+class TemplateEmbedder(nn.Module):
+    def __init__(
+        self,
+        config,
+    ):
+        super().__init__()
+
+        self.config = config
+        self.template_angle_embedder = TemplateAngleEmbedder(
+            **config["template_angle_embedder"],
+        )
+        self.template_pair_embedder = TemplatePairEmbedder(
+            **config["template_pair_embedder"],
+        )
+        self.template_pair_stack = TemplatePairStack(
+            **config["template_pair_stack"],
+        )
+        self.template_pointwise_att = TemplatePointwiseAttention(
+            **config["template_pointwise_attention"],
+        )
+
+    def forward(
+        self,
+        batch,
+        z,
+        pair_mask,
+        templ_dim,
+        chunk_size,
+        _mask_trans=True,
+    ):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[templ_dim]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            single_template_embeds = {}
+            if self.config.embed_angles:
+                template_angle_feat = build_template_angle_feat(
+                    single_template_feats,
+                )
+
+                # [*, S_t, N, C_m]
+                a = self.template_angle_embedder(template_angle_feat)
+
+                single_template_embeds["angle"] = a
+
+            # [*, S_t, N, N, C_t]
+            t = build_template_pair_feat(
+                single_template_feats,
+                use_unit_vector=self.config.use_unit_vector,
+                inf=self.config.inf,
+                eps=self.config.eps,
+                **self.config.distogram,
+            ).to(z.dtype)
+            t = self.template_pair_embedder(t)
+
+            single_template_embeds.update({"pair": t})
+
+            template_embeds.append(single_template_embeds)
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["pair"],
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        # [*, N, N, C_z]
+        t = self.template_pointwise_att(
+            t,
+            z,
+            template_mask=batch["template_mask"].to(dtype=z.dtype),
+            chunk_size=chunk_size,
+        )
+        t = t * (torch.sum(batch["template_mask"]) > 0)
+
+        ret = {}
+        if self.config.embed_angles:
+            ret["template_pair_embedding"] = template_embeds["angle"]
+
+        ret.update({"template_pair_embedding": t})
+
+        return ret

 class ExtraMSAEmbedder(nn.Module):
     """
     Embeds unclustered MSA sequences.
...
@@ -350,3 +608,315 @@ class ExtraMSAEmbedder(nn.Module):
         x = self.linear(x)

         return x
+
+
+class TemplateEmbedder(nn.Module):
+    def __init__(self, config):
+        super(TemplateEmbedder, self).__init__()
+
+        self.config = config
+        self.template_angle_embedder = TemplateAngleEmbedder(
+            **config["template_angle_embedder"],
+        )
+        self.template_pair_embedder = TemplatePairEmbedder(
+            **config["template_pair_embedder"],
+        )
+        self.template_pair_stack = TemplatePairStack(
+            **config["template_pair_stack"],
+        )
+        self.template_pointwise_att = TemplatePointwiseAttention(
+            **config["template_pointwise_attention"],
+        )
+
+    def forward(
+        self,
+        batch,
+        z,
+        pair_mask,
+        templ_dim,
+        chunk_size,
+        _mask_trans=True,
+    ):
+        # Embed the templates one at a time (with a poor man's vmap)
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[templ_dim]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            single_template_embeds = {}
+            if self.config.embed_angles:
+                template_angle_feat = build_template_angle_feat(
+                    single_template_feats,
+                )
+
+                # [*, S_t, N, C_m]
+                a = self.template_angle_embedder(template_angle_feat)
+
+                single_template_embeds["angle"] = a
+
+            # [*, S_t, N, N, C_t]
+            t = build_template_pair_feat(
+                single_template_feats,
+                use_unit_vector=self.config.use_unit_vector,
+                inf=self.config.inf,
+                eps=self.config.eps,
+                **self.config.distogram,
+            ).to(z.dtype)
+            t = self.template_pair_embedder(t)
+
+            single_template_embeds.update({"pair": t})
+
+            template_embeds.append(single_template_embeds)
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["pair"],
+            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        # [*, N, N, C_z]
+        t = self.template_pointwise_att(
+            t,
+            z,
+            template_mask=batch["template_mask"].to(dtype=z.dtype),
+            chunk_size=chunk_size,
+        )
+        t = t * (torch.sum(batch["template_mask"]) > 0)
+
+        ret = {}
+        if self.config.embed_angles:
+            ret["template_single_embedding"] = template_embeds["angle"]
+
+        ret.update({"template_pair_embedding": t})
+
+        return ret
+class TemplatePairEmbedderMultimer(nn.Module):
+    def __init__(
+        self,
+        c_z: int,
+        c_out: int,
+        c_dgram: int,
+        c_aatype: int,
+    ):
+        super().__init__()
+
+        self.dgram_linear = Linear(c_dgram, c_out)
+        self.aatype_linear_1 = Linear(c_aatype, c_out)
+        self.aatype_linear_2 = Linear(c_aatype, c_out)
+        self.query_embedding_layer_norm = LayerNorm(c_z)
+        self.query_embedding_linear = Linear(c_z, c_out)
+        self.pseudo_beta_mask_linear = Linear(1, c_out)
+        self.x_linear = Linear(1, c_out)
+        self.y_linear = Linear(1, c_out)
+        self.z_linear = Linear(1, c_out)
+        self.backbone_mask_linear = Linear(1, c_out)
+
+    def forward(
+        self,
+        template_dgram: torch.Tensor,
+        aatype_one_hot: torch.Tensor,
+        query_embedding: torch.Tensor,
+        pseudo_beta_mask: torch.Tensor,
+        backbone_mask: torch.Tensor,
+        multichain_mask_2d: torch.Tensor,
+        unit_vector: geometry.Vec3Array,
+    ) -> torch.Tensor:
+        act = 0.
+
+        pseudo_beta_mask_2d = (
+            pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
+        )
+        pseudo_beta_mask_2d *= multichain_mask_2d
+        template_dgram *= pseudo_beta_mask_2d[..., None]
+        act += self.dgram_linear(template_dgram)
+        act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
+
+        aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
+        act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
+        act += self.aatype_linear_2(aatype_one_hot[..., None, :])
+
+        backbone_mask_2d = (
+            backbone_mask[..., None] * backbone_mask[..., None, :]
+        )
+        backbone_mask_2d *= multichain_mask_2d
+        x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
+        act += self.x_linear(x[..., None])
+        act += self.y_linear(y[..., None])
+        act += self.z_linear(z[..., None])
+
+        act += self.backbone_mask_linear(backbone_mask_2d[..., None])
+
+        query_embedding = self.query_embedding_layer_norm(query_embedding)
+        act += self.query_embedding_linear(query_embedding)
+
+        return act
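Note the additive pattern above: each template feature gets its own Linear into c_out, and the projections are summed into a single pair activation rather than concatenated. A toy-sized sketch of the same pattern in plain PyTorch; the 39-bin distogram width is an illustrative assumption, not a value taken from this diff:

    import torch
    from torch import nn

    c_out = 8
    dgram_linear = nn.Linear(39, c_out)
    mask_linear = nn.Linear(1, c_out)

    template_dgram = torch.randn(4, 4, 39)   # [N, N, no_bins]
    mask_2d = torch.ones(4, 4)               # [N, N]

    act = dgram_linear(template_dgram) + mask_linear(mask_2d[..., None])
    print(act.shape)  # torch.Size([4, 4, 8])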
+class TemplateSingleEmbedderMultimer(nn.Module):
+    def __init__(
+        self,
+        c_in: int,
+        c_m: int,
+    ):
+        super().__init__()
+        self.template_single_embedder = Linear(c_in, c_m)
+        self.template_projector = Linear(c_m, c_m)
+
+    def forward(
+        self,
+        batch,
+        atom_pos,
+        aatype_one_hot,
+    ):
+        out = {}
+
+        template_chi_angles, template_chi_mask = (
+            all_atom_multimer.compute_chi_angles(
+                atom_pos,
+                batch["template_all_atom_mask"],
+                batch["template_aatype"],
+            )
+        )
+
+        template_features = torch.cat(
+            [
+                aatype_one_hot,
+                torch.sin(template_chi_angles) * template_chi_mask,
+                torch.cos(template_chi_angles) * template_chi_mask,
+                template_chi_mask,
+            ],
+            dim=-1,
+        )
+
+        template_mask = template_chi_mask[..., 0]
+
+        template_activations = self.template_single_embedder(
+            template_features
+        )
+        template_activations = torch.nn.functional.relu(
+            template_activations
+        )
+        template_activations = self.template_projector(
+            template_activations,
+        )
+
+        out["template_single_embedding"] = (
+            template_activations
+        )
+        out["template_mask"] = template_mask
+
+        return out
+class TemplateEmbedderMultimer(nn.Module):
+    def __init__(self, config):
+        super(TemplateEmbedderMultimer, self).__init__()
+
+        self.config = config
+        self.template_pair_embedder = TemplatePairEmbedderMultimer(
+            **config["template_pair_embedder"],
+        )
+        self.template_single_embedder = TemplateSingleEmbedderMultimer(
+            **config["template_single_embedder"],
+        )
+        self.template_pair_stack = TemplatePairStack(
+            **config["template_pair_stack"],
+        )
+
+        self.linear_t = Linear(config.c_t, config.c_z)
+
+    def forward(
+        self,
+        batch,
+        z,
+        padding_mask_2d,
+        templ_dim,
+        chunk_size,
+        multichain_mask_2d,
+    ):
+        template_embeds = []
+        n_templ = batch["template_aatype"].shape[templ_dim]
+        for i in range(n_templ):
+            idx = batch["template_aatype"].new_tensor(i)
+            single_template_feats = tensor_tree_map(
+                lambda t: torch.index_select(t, templ_dim, idx),
+                batch,
+            )
+
+            single_template_embeds = {}
+            act = 0.
+
+            template_positions, pseudo_beta_mask = (
+                single_template_feats["template_pseudo_beta"],
+                single_template_feats["template_pseudo_beta_mask"],
+            )
+
+            template_dgram = dgram_from_positions(
+                template_positions,
+                inf=self.config.inf,
+                **self.config.distogram,
+            )
+
+            aatype_one_hot = torch.nn.functional.one_hot(
+                single_template_feats["template_aatype"],
+                22,
+            )
+
+            raw_atom_pos = single_template_feats["template_all_atom_positions"]
+            atom_pos = geometry.Vec3Array.from_tensor(raw_atom_pos)
+
+            rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
+                atom_pos,
+                single_template_feats["template_all_atom_mask"],
+                single_template_feats["template_aatype"],
+            )
+            points = rigid.translation
+            rigid_vec = rigid[..., None].inverse().apply_to_point(points)
+            unit_vector = rigid_vec.normalized()
+
+            pair_act = self.template_pair_embedder(
+                template_dgram,
+                aatype_one_hot,
+                z,
+                pseudo_beta_mask,
+                backbone_mask,
+                multichain_mask_2d,
+                unit_vector,
+            )
+
+            single_template_embeds["template_pair_embedding"] = pair_act
+            single_template_embeds.update(
+                self.template_single_embedder(
+                    single_template_feats,
+                    atom_pos,
+                    aatype_one_hot,
+                )
+            )
+            template_embeds.append(single_template_embeds)
+
+        template_embeds = dict_multimap(
+            partial(torch.cat, dim=templ_dim),
+            template_embeds,
+        )
+
+        # [*, S_t, N, N, C_z]
+        t = self.template_pair_stack(
+            template_embeds["template_pair_embedding"],
+            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
+            chunk_size=chunk_size,
+            _mask_trans=False,
+        )
+
+        # [*, N, N, C_z]
+        t = torch.sum(t, dim=-4) / n_templ
+        t = torch.nn.functional.relu(t)
+        t = self.linear_t(t)
+        template_embeds["template_pair_embedding"] = t
+
+        return template_embeds
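Two multimer-specific choices stand out in this class: template pair features are restricted to intra-chain pairs via multichain_mask_2d, and templates are merged by a plain average rather than by pointwise attention. A toy sketch of both steps in isolation:

    import torch

    asym_id = torch.tensor([0, 0, 1, 1])
    # [N, N] bool, True only for residue pairs within the same chain
    multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :]

    n_templ, n_res, c_t = 4, 4, 8
    t = torch.randn(n_templ, n_res, n_res, c_t)  # stand-in for pair-stack output
    t = torch.sum(t, dim=-4) / n_templ           # [N, N, C_t]: mean over templates
    t = torch.nn.functional.relu(t)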
openfold/model/model.py
...
@@ -17,28 +17,25 @@ from functools import partial
 import torch
 import torch.nn as nn

+from openfold.data import data_transforms_multimer
 from openfold.utils.feats import (
     pseudo_beta_fn,
     build_extra_msa_feat,
-    build_template_angle_feat,
-    build_template_pair_feat,
+    dgram_from_positions,
     atom14_to_atom37,
 )
 from openfold.model.embedders import (
     InputEmbedder,
+    InputEmbedderMultimer,
     RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
+    TemplateEmbedder,
+    TemplateEmbedderMultimer,
     ExtraMSAEmbedder,
 )
 from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
 from openfold.model.heads import AuxiliaryHeads
 import openfold.np.residue_constants as residue_constants
 from openfold.model.structure_module import StructureModule
-from openfold.model.template import (
-    TemplatePairStack,
-    TemplatePointwiseAttention,
-)
 from openfold.utils.loss import (
     compute_plddt,
 )
...
@@ -69,24 +66,28 @@ class AlphaFold(nn.Module):
         extra_msa_config = config.extra_msa

         # Main trunk + structure module
-        self.input_embedder = InputEmbedder(
-            **config["input_embedder"],
-        )
+        if(self.globals.is_multimer):
+            self.input_embedder = InputEmbedderMultimer(
+                **config["input_embedder"],
+            )
+        else:
+            self.input_embedder = InputEmbedder(
+                **config["input_embedder"],
+            )
+
         self.recycling_embedder = RecyclingEmbedder(
             **config["recycling_embedder"],
         )
-        self.template_angle_embedder = TemplateAngleEmbedder(
-            **template_config["template_angle_embedder"],
-        )
-        self.template_pair_embedder = TemplatePairEmbedder(
-            **template_config["template_pair_embedder"],
-        )
-        self.template_pair_stack = TemplatePairStack(
-            **template_config["template_pair_stack"],
-        )
-        self.template_pointwise_att = TemplatePointwiseAttention(
-            **template_config["template_pointwise_attention"],
-        )
+        if(self.globals.is_multimer):
+            self.template_embedder = TemplateEmbedderMultimer(
+                template_config,
+            )
+        else:
+            self.template_embedder = TemplateEmbedder(
+                template_config,
+            )
         self.extra_msa_embedder = ExtraMSAEmbedder(
             **extra_msa_config["extra_msa_embedder"],
         )
...
@@ -96,7 +97,9 @@ class AlphaFold(nn.Module):
         self.evoformer = EvoformerStack(
             **config["evoformer_stack"],
         )
         self.structure_module = StructureModule(
+            is_multimer=self.globals.is_multimer,
             **config["structure_module"],
         )
...
@@ -106,71 +109,6 @@ class AlphaFold(nn.Module):
         self.config = config

-    def embed_templates(self, batch, z, pair_mask, templ_dim):
-        # Embed the templates one at a time (with a poor man's vmap)
-        template_embeds = []
-        n_templ = batch["template_aatype"].shape[templ_dim]
-        for i in range(n_templ):
-            idx = batch["template_aatype"].new_tensor(i)
-            single_template_feats = tensor_tree_map(
-                lambda t: torch.index_select(t, templ_dim, idx),
-                batch,
-            )
-
-            single_template_embeds = {}
-            if self.config.template.embed_angles:
-                template_angle_feat = build_template_angle_feat(
-                    single_template_feats,
-                )
-
-                # [*, S_t, N, C_m]
-                a = self.template_angle_embedder(template_angle_feat)
-
-                single_template_embeds["angle"] = a
-
-            # [*, S_t, N, N, C_t]
-            t = build_template_pair_feat(
-                single_template_feats,
-                inf=self.config.template.inf,
-                eps=self.config.template.eps,
-                **self.config.template.distogram,
-            ).to(z.dtype)
-            t = self.template_pair_embedder(t)
-
-            single_template_embeds.update({"pair": t})
-
-            template_embeds.append(single_template_embeds)
-
-        template_embeds = dict_multimap(
-            partial(torch.cat, dim=templ_dim),
-            template_embeds,
-        )
-
-        # [*, S_t, N, N, C_z]
-        t = self.template_pair_stack(
-            template_embeds["pair"],
-            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-            chunk_size=self.globals.chunk_size,
-            _mask_trans=self.config._mask_trans,
-        )
-
-        # [*, N, N, C_z]
-        t = self.template_pointwise_att(
-            t,
-            z,
-            template_mask=batch["template_mask"].to(dtype=z.dtype),
-            chunk_size=self.globals.chunk_size,
-        )
-        t = t * (torch.sum(batch["template_mask"]) > 0)
-
-        ret = {}
-        if self.config.template.embed_angles:
-            ret["template_angle_embedding"] = template_embeds["angle"]
-
-        ret.update({"template_pair_embedding": t})
-
-        return ret
-
     def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
         # Primary output dictionary
         outputs = {}
...
@@ -197,11 +135,7 @@ class AlphaFold(nn.Module):
         # m: [*, S_c, N, C_m]
         # z: [*, N, N, C_z]
-        m, z = self.input_embedder(
-            feats["target_feat"],
-            feats["residue_index"],
-            feats["msa_feat"],
-        )
+        m, z = self.input_embedder(feats)

         # Initialize the recycling embeddings, if needs be
         if None in [m_1_prev, z_prev, x_prev]:
...
@@ -257,40 +191,74 @@ class AlphaFold(nn.Module):
             template_feats = {
                 k: v for k, v in feats.items() if k.startswith("template_")
             }
-            template_embeds = self.embed_templates(
-                template_feats,
-                z,
-                pair_mask.to(dtype=z.dtype),
-                no_batch_dims,
-            )
+            if(self.globals.is_multimer):
+                asym_id = feats["asym_id"]
+                multichain_mask_2d = (
+                    asym_id[..., None] == asym_id[..., None, :]
+                )
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    chunk_size=self.globals.chunk_size,
+                    multichain_mask_2d=multichain_mask_2d,
+                )
+                feats["template_torsion_angles_mask"] = (
+                    template_embeds["template_mask"]
+                )
+            else:
+                template_embeds = self.template_embedder(
+                    template_feats,
+                    z,
+                    pair_mask.to(dtype=z.dtype),
+                    no_batch_dims,
+                    self.globals.chunk_size
+                )

             # [*, N, N, C_z]
             z = z + template_embeds["template_pair_embedding"]

-            if self.config.template.embed_angles:
+            if(
+                self.config.template.embed_angles or
+                (self.globals.is_multimer and self.config.template.enabled)
+            ):
                 # [*, S = S_c + S_t, N, C_m]
                 m = torch.cat(
-                    [m, template_embeds["template_angle_embedding"]],
+                    [m, template_embeds["template_single_embedding"]],
                     dim=-3
                 )

                 # [*, S, N]
-                torsion_angles_mask = feats["template_torsion_angles_mask"]
-                msa_mask = torch.cat(
-                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
-                    dim=-2
-                )
+                if(not self.globals.is_multimer):
+                    torsion_angles_mask = feats["template_torsion_angles_mask"]
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], torsion_angles_mask[..., 2]],
+                        dim=-2
+                    )
+                else:
+                    msa_mask = torch.cat(
+                        [feats["msa_mask"], template_embeds["template_mask"]],
+                        dim=-2,
+                    )
)
# Embed extra MSA features + merge with pairwise embeddings
# Embed extra MSA features + merge with pairwise embeddings
if
self
.
config
.
extra_msa
.
enabled
:
if
self
.
config
.
extra_msa
.
enabled
:
if
(
self
.
globals
.
is_multimer
):
extra_msa_fn
=
data_transforms_multimer
.
build_extra_msa_feat
else
:
extra_msa_fn
=
build_extra_msa_feat
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
extra_msa_feat
=
extra_msa_fn
(
feats
)
extra_msa_feat
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
z
=
self
.
extra_msa_stack
(
a
,
extra_msa_feat
,
z
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
a
.
dtype
),
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
extra_msa_feat
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
...
@@ -340,14 +308,14 @@ class AlphaFold(nn.Module):
         return outputs, m_1_prev, z_prev, x_prev

     def _disable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = None
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None

         for b in self.extra_msa_stack.blocks:
             b.ckpt = False

     def _enable_activation_checkpointing(self):
-        self.template_pair_stack.blocks_per_ckpt = (
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = (
             self.config.template.template_pair_stack.blocks_per_ckpt
         )
         self.evoformer.blocks_per_ckpt = (
...
openfold/model/structure_module.py
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
     restype_atom14_mask,
     restype_atom14_rigid_group_positions,
 )
+from openfold.utils.geometry.quat_rigid import QuatRigid
+from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from openfold.utils.geometry.vector import Vec3Array
 from openfold.utils.feats import (
     frames_and_literature_positions_to_atom14_pos,
     torsion_angles_to_frames,
...
@@ -155,14 +158,14 @@ class PointProjection(nn.Module):
     def __init__(self,
         c_hidden: int,
         num_points: int,
-        no_heads: int
+        no_heads: int,
         return_local_points: bool = False,
     ):
         super().__init__()
         self.return_local_points = return_local_points
         self.no_heads = no_heads

-        self.linear = Linear(c_hidden, 3 * num_points)
+        self.linear = Linear(c_hidden, no_heads * 3 * num_points)

     def forward(self,
         activations: torch.Tensor,
...
@@ -171,11 +174,13 @@ class PointProjection(nn.Module):
         # TODO: Needs to run in high precision during training
         points_local = self.linear(activations)
         points_local = points_local.reshape(
-            points_local.shape[:-1],
+            *points_local.shape[:-1],
             self.no_heads,
             -1,
         )
-        points_local = torch.split(points_local, 3, dim=-1)
+        points_local = torch.split(
+            points_local, points_local.shape[-1] // 3, dim=-1
+        )
         points_local = Vec3Array(*points_local)
         points_global = rigids[..., None, None].apply_to_point(points_local)
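The reshape/split change means a single Linear now emits all heads' points at once, and the flattened output is cut into equal x/y/z thirds instead of fixed chunks of three. A toy check of the resulting shapes:

    import torch

    no_heads, num_points, c_hidden = 2, 4, 16
    linear = torch.nn.Linear(c_hidden, no_heads * 3 * num_points)

    activations = torch.randn(5, c_hidden)            # [N_res, C_hidden]
    points_local = linear(activations)
    points_local = points_local.reshape(*points_local.shape[:-1], no_heads, -1)
    x, y, z = torch.split(points_local, points_local.shape[-1] // 3, dim=-1)
    print(x.shape)  # torch.Size([5, 2, 4]): one xyz component per head and point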
...
@@ -184,7 +189,7 @@ class PointProjection(nn.Module):

         return points_global

-
+# WEIGHTS CHANGED
 class InvariantPointAttention(nn.Module):
     """
     Implements Algorithm 22.
...
@@ -199,6 +204,7 @@ class InvariantPointAttention(nn.Module):
         no_v_points: int,
         inf: float = 1e5,
         eps: float = 1e-8,
+        is_multimer: bool = False,
     ):
         """
         Args:
...
@@ -225,14 +231,14 @@ class InvariantPointAttention(nn.Module):
         self.no_v_points = no_v_points
         self.inf = inf
         self.eps = eps
+        self.is_multimer = is_multimer

         # These linear layers differ from their specifications in the
         # supplement. There, they lack bias and use Glorot initialization.
         # Here as in the official source, they have bias and use the default
         # Lecun initialization.
         hc = self.c_hidden * self.no_heads
-        self.linear_q = Linear(self.c_s, hc)
-        self.linear_kv = Linear(self.c_s, 2 * hc)
+        self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))

         self.linear_q_points = PointProjection(
             self.c_s,
...
@@ -240,15 +246,25 @@ class InvariantPointAttention(nn.Module):
             self.no_heads
         )

-        self.linear_k_points = PointProjection(
-            self.c_s,
-            self.no_qk_points,
-            self.no_heads,
-        )
-        self.linear_v_points = PointProjection(
-            self.c_s,
-            self.no_v_points,
-            self.no_heads,
-        )
+        if(is_multimer):
+            self.linear_k = Linear(self.c_s, hc, bias=False)
+            self.linear_v = Linear(self.c_s, hc, bias=False)
+
+            self.linear_k_points = PointProjection(
+                self.c_s,
+                self.no_qk_points,
+                self.no_heads,
+            )
+            self.linear_v_points = PointProjection(
+                self.c_s,
+                self.no_v_points,
+                self.no_heads,
+            )
+        else:
+            self.linear_kv = Linear(self.c_s, 2 * hc)
+
+            self.linear_kv_points = PointProjection(
+                self.c_s,
+                self.no_qk_points + self.no_v_points,
+                self.no_heads,
+            )
...
@@ -290,25 +306,48 @@ class InvariantPointAttention(nn.Module):
         #######################################
         # [*, N_res, H * C_hidden]
         q = self.linear_q(s)
-        kv = self.linear_kv(s)

         # [*, N_res, H, C_hidden]
         q = q.view(q.shape[:-1] + (self.no_heads, -1))

-        # [*, N_res, H, 2 * C_hidden]
-        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
-
-        # [*, N_res, H, C_hidden]
-        k, v = torch.split(kv, self.c_hidden, dim=-1)
-
         # [*, N_res, H, P_qk]
         q_pts = self.linear_q_points(s, r)

-        # [*, N_res, H, P_qk, 3]
-        k_pts = self.linear_k_points(s, r)
-
-        # [*, N_res, H, P_v, 3]
-        v_pts = self.linear_v_points(s, r)
+        # The following two blocks are equivalent
+        # They're separated only to preserve compatibility with old AF weights
+        if(self.is_multimer):
+            # [*, N_res, H * C_hidden]
+            k = self.linear_k(s)
+            v = self.linear_v(s)
+
+            # [*, N_res, H, C_hidden]
+            k = k.view(k.shape[:-1] + (self.no_heads, -1))
+            v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+            # [*, N_res, H, P_qk, 3]
+            k_pts = self.linear_k_points(s, r)
+
+            # [*, N_res, H, P_v, 3]
+            v_pts = self.linear_v_points(s, r)
+        else:
+            # [*, N_res, H * 2 * C_hidden]
+            kv = self.linear_kv(s)
+
+            # [*, N_res, H, 2 * C_hidden]
+            kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
+
+            # [*, N_res, H, C_hidden]
+            k, v = torch.split(kv, self.c_hidden, dim=-1)
+
+            kv_pts = self.linear_kv_points(s, r)
+
+            # [*, N_res, H, (P_q + P_v), 3]
+            kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
+
+            # [*, N_res, H, P_q/P_v, 3]
+            k_pts, v_pts = torch.split(
+                kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
+            )

         ##########################
         # Compute attention scores
...
@@ -324,12 +363,14 @@ class InvariantPointAttention(nn.Module):
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
         a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

+        for c in q_pts:
+            print(type(c))
+
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
-        pt_att = pt_att * pt_att + self.eps

         # [*, N_res, N_res, H, P_q]
-        pt_att = sum(torch.unbind(pt_att, dim=-1))
+        pt_att = sum([c ** 2 for c in pt_att])
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )
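The new component-wise sum of squares is the same reduction the old tensor path performed; iterating the Vec3Array displacement yields its x/y/z components. A quick numerical check, using torch.unbind to stand in for iterating a Vec3Array:

    import torch

    d = torch.randn(3, 3, 4, 5, 3)                        # [..., P_q, 3] displacement
    old = sum(torch.unbind(d * d, dim=-1))                # tensor path (pre-multimer)
    new = sum([c ** 2 for c in torch.unbind(d, dim=-1)])  # component-wise path
    assert torch.allclose(old, new)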
...
@@ -364,9 +405,7 @@ class InvariantPointAttention(nn.Module):
         # As DeepMind explains, this manual matmul ensures that the operation
         # happens in float32.
         # [*, N_res, H, P_v]
-        o_pt = v_pts.tensor_dot(
-            permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
-        )
+        o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
         o_pt = o_pt.sum(dim=-3)

         # [*, N_res, H, P_v]
...
@@ -493,6 +532,7 @@ class StructureModule(nn.Module):
         trans_scale_factor,
         epsilon,
         inf,
+        is_multimer=False,
         **kwargs,
     ):
         """
...
@@ -546,6 +586,7 @@ class StructureModule(nn.Module):
         self.trans_scale_factor = trans_scale_factor
         self.epsilon = epsilon
         self.inf = inf
+        self.is_multimer = is_multimer

         # To be lazily initialized later
         self.default_frames = None
...
@@ -567,6 +608,7 @@ class StructureModule(nn.Module):
             self.no_v_points,
             inf=self.inf,
             eps=self.epsilon,
+            is_multimer=self.is_multimer,
         )

         self.ipa_dropout = nn.Dropout(self.dropout_rate)
...
@@ -588,26 +630,61 @@ class StructureModule(nn.Module):
             self.epsilon,
         )

+    def _init_residue_constants(self, float_dtype, device):
+        if self.default_frames is None:
+            self.default_frames = torch.tensor(
+                restype_rigid_group_default_frame,
+                dtype=float_dtype,
+                device=device,
+                requires_grad=False,
+            )
+        if self.group_idx is None:
+            self.group_idx = torch.tensor(
+                restype_atom14_to_rigid_group,
+                device=device,
+                requires_grad=False,
+            )
+        if self.atom_mask is None:
+            self.atom_mask = torch.tensor(
+                restype_atom14_mask,
+                dtype=float_dtype,
+                device=device,
+                requires_grad=False,
+            )
+        if self.lit_positions is None:
+            self.lit_positions = torch.tensor(
+                restype_atom14_rigid_group_positions,
+                dtype=float_dtype,
+                device=device,
+                requires_grad=False,
+            )
+
+    def torsion_angles_to_frames(self, r, alpha, f):
+        # Lazily initialize the residue constants on the correct device
+        self._init_residue_constants(alpha.dtype, alpha.device)
+        # Separated purely to make testing less annoying
+        return torsion_angles_to_frames(r, alpha, f, self.default_frames)
+
+    def frames_and_literature_positions_to_atom14_pos(
+        self, r, f  # [*, N, 8]  # [*, N]
+    ):
+        # Lazily initialize the residue constants on the correct device
+        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
+        return frames_and_literature_positions_to_atom14_pos(
+            r,
+            f,
+            self.default_frames,
+            self.group_idx,
+            self.atom_mask,
+            self.lit_positions,
+        )
+
-    def forward(self,
+    def _forward_monomer(self,
         s,
         z,
         aatype,
         mask=None,
     ):
-        """
-        Args:
-            s:
-                [*, N_res, C_s] single representation
-            z:
-                [*, N_res, N_res, C_z] pair representation
-            aatype:
-                [*, N_res] amino acid indices
-            mask:
-                Optional [*, N_res] sequence mask
-        Returns:
-            A dictionary of outputs
-        """
         if mask is None:
             # [*, N]
             mask = s.new_ones(s.shape[:-1])
...
@@ -690,51 +767,97 @@ class StructureModule(nn.Module):
         return outputs

-    def _init_residue_constants(self, float_dtype, device):
-        if self.default_frames is None:
-            self.default_frames = torch.tensor(
-                restype_rigid_group_default_frame,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
-        if self.group_idx is None:
-            self.group_idx = torch.tensor(
-                restype_atom14_to_rigid_group,
-                device=device,
-                requires_grad=False,
-            )
-        if self.atom_mask is None:
-            self.atom_mask = torch.tensor(
-                restype_atom14_mask,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
-        if self.lit_positions is None:
-            self.lit_positions = torch.tensor(
-                restype_atom14_rigid_group_positions,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
-
-    def torsion_angles_to_frames(self, r, alpha, f):
-        # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(alpha.dtype, alpha.device)
-        # Separated purely to make testing less annoying
-        return torsion_angles_to_frames(r, alpha, f, self.default_frames)
-
-    def frames_and_literature_positions_to_atom14_pos(
-        self, r, f  # [*, N, 8]  # [*, N]
-    ):
-        # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
-        return frames_and_literature_positions_to_atom14_pos(
-            r,
-            f,
-            self.default_frames,
-            self.group_idx,
-            self.atom_mask,
-            self.lit_positions,
-        )
+    def _forward_multimer(self,
+        s,
+        z,
+        aatype,
+        mask=None,
+    ):
+        if mask is None:
+            # [*, N]
+            mask = s.new_ones(s.shape[:-1])
+
+        # [*, N, C_s]
+        s = self.layer_norm_s(s)
+
+        # [*, N, N, C_z]
+        z = self.layer_norm_z(z)
+
+        # [*, N, C_s]
+        s_initial = s
+        s = self.linear_in(s)
+
+        # [*, N]
+        rigids = Rigid3Array.identity(
+            s.shape[:-1],
+            s.device,
+        )
+        outputs = []
+        for i in range(self.no_blocks):
+            # [*, N, C_s]
+            s = s + self.ipa(s, z, rigids, mask)
+            s = self.ipa_dropout(s)
+            s = self.layer_norm_ipa(s)
+            s = self.transition(s)
+
+            # [*, N]
+            rigids = rigids @ self.bb_update(s)
+
+            # [*, N, 7, 2]
+            unnormalized_angles, angles = self.angle_resnet(s, s_initial)
+
+            all_frames_to_global = self.torsion_angles_to_frames(
+                rigids.scale_translation(self.trans_scale_factor),
+                angles,
+                aatype,
+            )
+
+            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
+                all_frames_to_global,
+                aatype,
+            )
+
+            preds = {
+                "frames": rigids.scale_translation(
+                    self.trans_scale_factor
+                ).to_tensor7(),
+                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
+                "unnormalized_angles": unnormalized_angles,
+                "angles": angles,
+                "positions": pred_xyz,
+            }
+
+            outputs.append(preds)
+
+            if i < (self.no_blocks - 1):
+                rigids = rigids.stop_rot_gradient()
+
+        outputs = dict_multimap(torch.stack, outputs)
+        outputs["single"] = s
+
+        return outputs
+
+    def forward(self,
+        s,
+        z,
+        aatype,
+        mask=None,
+    ):
+        """
+        Args:
+            s:
+                [*, N_res, C_s] single representation
+            z:
+                [*, N_res, N_res, C_z] pair representation
+            aatype:
+                [*, N_res] amino acid indices
+            mask:
+                Optional [*, N_res] sequence mask
+        Returns:
+            A dictionary of outputs
+        """
+        if(self.is_multimer):
+            outputs = self._forward_multimer(s, z, aatype, mask)
+        else:
+            outputs = self._forward_monomer(s, z, aatype, mask)
+
+        return outputs
openfold/np/protein.py
...
@@ -62,7 +62,7 @@ class Protein:
     b_factors: np.ndarray  # [num_res, num_atom_type]

     def __post_init__(self):
-        if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
+        if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
             raise ValueError(
                 f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
                 "chains because these cannot be written to PDB format"
...
openfold/utils/all_atom_multimer.py (new file, mode 100644)
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""

from functools import partial
from typing import Dict, Text, Tuple

import torch
import numpy as np

from openfold.np import residue_constants as rc
from openfold.utils import geometry, tensor_utils


def squared_difference(x, y):
    # torch equivalent of the jnp.square call in the JAX original
    return torch.square(x - y)


def get_rc_tensor(rc_np, aatype):
    return torch.tensor(rc_np, device=aatype.device)[aatype]


def atom14_to_atom37(
    atom14_data: torch.Tensor,  # (*, N, 14, ...)
    aatype: torch.Tensor,  # (*, N)
) -> torch.Tensor:  # (*, N, 37, ...)
    """Convert atom14 to atom37 representation."""
    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
    no_batch_dims = len(aatype.shape) - 1
    atom37_data = tensor_utils.batched_gather(
        atom14_data,
        idx_atom37_to_atom14,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
    if len(atom14_data.shape) == no_batch_dims + 2:
        atom37_data *= atom37_mask
    elif len(atom14_data.shape) == no_batch_dims + 3:
        atom37_data *= atom37_mask[..., None].to(atom37_data.dtype)
    else:
        raise ValueError("Incorrectly shaped data")

    return atom37_data
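batched_gather here is essentially a per-residue torch.gather with an index table mapping the dense 14-atom layout into 37-atom slots. A minimal pure-torch illustration of the idea with toy sizes; the random index table is a stand-in for rc.RESTYPE_ATOM37_TO_ATOM14:

    import torch

    atom14 = torch.randn(2, 14, 3)               # [N_res, 14, xyz]
    idx37_to_14 = torch.randint(0, 14, (2, 37))  # toy per-residue index table
    atom37 = torch.gather(
        atom14, 1, idx37_to_14[..., None].expand(-1, -1, 3)
    )
    print(atom37.shape)  # torch.Size([2, 37, 3])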
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
    """Convert Atom37 positions to Atom14 positions."""
    residx_atom14_to_atom37 = get_rc_tensor(
        rc.RESTYPE_ATOM14_TO_ATOM37, aatype
    )
    no_batch_dims = len(aatype.shape)
    atom14_mask = tensor_utils.batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    ).to(torch.float32)
    # create a mask for known groundtruth positions
    atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
    # gather the groundtruth positions
    atom14_positions = tensor_utils.batched_gather(
        all_atom_pos,
        residx_atom14_to_atom37,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    atom14_positions = atom14_mask * atom14_positions
    return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: torch.Tensor, mask):
    """Get alternative atom14 positions."""
    # pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14)
    renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype)

    alternative_positions = torch.sum(
        positions[..., None, :] * renaming_transform[..., None], dim=-2
    )

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position)
    alternative_mask = torch.sum(
        mask[..., None] * renaming_transform, dim=-2
    )

    return alternative_positions, alternative_mask
def atom37_to_frames(
    aatype: torch.Tensor,  # (...)
    all_atom_positions: torch.Tensor,  # (..., 37)
    all_atom_mask: torch.Tensor,  # (..., 37)
) -> Dict[Text, torch.Tensor]:
    """Computes the frames for the up to 8 rigid groups for each residue."""
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    no_batch_dims = len(aatype.shape) - 1

    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = get_rc_tensor(
        rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype
    )

    # Gather the base atom positions for each rigid group.
    base_atom_pos = tensor_utils.batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=no_batch_dims + 1,
        batch_dims=no_batch_dims + 1,
    )

    # Compute the Rigids.
    point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
    origin = base_atom_pos[..., :, :, 1]
    point_on_xy_plane = base_atom_pos[..., :, :, 2]
    gt_rotation = geometry.Rot3Array.from_two_vectors(
        origin - point_on_neg_x_axis, point_on_xy_plane - origin
    )

    gt_frames = geometry.Rigid3Array(gt_rotation, origin)

    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)

    # Compute a mask whether ground truth exists for the group
    gt_atoms_exist = tensor_utils.batched_gather(  # shape (N, 8, 3)
        all_atom_mask.to(dtype=torch.float32),
        residx_rigidgroup_base_atom37_idx,
        batch_dims=no_batch_dims + 1,
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists  # (N, 8)

    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = gt_frames.compose_rotation(
        geometry.Rot3Array.from_array(
            torch.tensor(rots, device=aatype.device)
        )
    )

    # The frames for ambiguous rigid groups are just rotated by 180 degree around
    # the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(
        np.eye(3, dtype=np.float32), [21, 8, 1, 1]
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = torch.tensor(
        restype_rigidgroup_is_ambiguous,
        device=aatype.device,
    )[aatype]
    ambiguity_rot = torch.tensor(
        restype_rigidgroup_rots,
        device=aatype.device,
    )[aatype]
    ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)

    # Create the alternative ground truth frames.
    alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)

    fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,))

    # reshape back to original residue layout
    gt_frames = fix_shape(gt_frames)
    gt_exists = fix_shape(gt_exists)
    group_exists = fix_shape(group_exists)
    residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
    alt_gt_frames = fix_shape(alt_gt_frames)

    return {
        'rigidgroups_gt_frames': gt_frames,  # Rigid (..., 8)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames,  # Rigid (..., 8)
    }
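Rot3Array.from_two_vectors conceptually builds an orthonormal frame by Gram-Schmidt: the first vector fixes e0, the second is orthogonalized against it to give e1, and e2 is their cross product. A pure-PyTorch sketch of that construction, not the repo's implementation:

    import torch

    def frame_from_two_vectors(v1, v2):
        e0 = v1 / v1.norm(dim=-1, keepdim=True)
        c = (e0 * v2).sum(dim=-1, keepdim=True)
        e1 = v2 - c * e0                      # remove the e0 component
        e1 = e1 / e1.norm(dim=-1, keepdim=True)
        e2 = torch.cross(e0, e1, dim=-1)
        return torch.stack([e0, e1, e2], dim=-1)  # rotation matrix, columns e0/e1/e2

    R = frame_from_two_vectors(
        torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.5, 1.0, 0.0])
    )
    print(R)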
def torsion_angles_to_frames(
    aatype: torch.Tensor,  # (N)
    backb_to_global: geometry.Rigid3Array,  # (N)
    torsion_angles_sin_cos: torch.Tensor,  # (N, 7, 2)
) -> geometry.Rigid3Array:  # (N, 8)
    """Compute rigid group frames from torsion angles."""
    # Gather the default frames for all rigid groups.
    # geometry.Rigid3Array with shape (N, 8)
    m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
    default_frames = geometry.Rigid3Array.from_array4x4(m)

    # Create the rotation matrices according to the given angles (each frame is
    # defined such that its rotation is around the x-axis).
    sin_angles = torsion_angles_sin_cos[..., 0]
    cos_angles = torsion_angles_sin_cos[..., 1]

    # insert zero rotation for backbone group.
    num_residues = aatype.shape[-1]
    sin_angles = torch.cat(
        [torch.zeros_like(aatype).unsqueeze(-1), sin_angles], dim=-1
    )
    cos_angles = torch.cat(
        [torch.ones_like(aatype).unsqueeze(-1), cos_angles], dim=-1
    )
    zeros = torch.zeros_like(sin_angles)
    ones = torch.ones_like(sin_angles)

    # all_rots are geometry.Rot3Array with shape (..., N, 8)
    all_rots = geometry.Rot3Array(
        ones, zeros, zeros,
        zeros, cos_angles, -sin_angles,
        zeros, sin_angles, cos_angles
    )

    # Apply rotations to the frames.
    all_frames = default_frames.compose_rotation(all_rots)

    # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
    # the previous frame. So chain them up accordingly.
    chi1_frame_to_backb = all_frames[..., 4]
    chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
    chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
    chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]

    all_frames_to_backb = geometry.Rigid3Array.cat(
        [
            all_frames[..., 0:5],
            chi2_frame_to_backb[..., None],
            chi3_frame_to_backb[..., None],
            chi4_frame_to_backb[..., None],
        ],
        dim=-1,
    )

    # Create the global frames.
    # shape (N, 8)
    all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb

    return all_frames_to_global
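Each torsion frame rotates about its local x-axis, which is why the Rot3Array above is filled with (1, 0, 0) in the first row and sin/cos terms in the lower block. A worked example of that same matrix applied to a vector:

    import math
    import torch

    angle = 0.3
    rot_x = torch.tensor([
        [1.0, 0.0,              0.0],
        [0.0, math.cos(angle), -math.sin(angle)],
        [0.0, math.sin(angle),  math.cos(angle)],
    ])
    v = torch.tensor([0.0, 1.0, 0.0])
    print(rot_x @ v)  # the y-axis rotated by 0.3 rad about x: (0, cos, sin)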
def frames_and_literature_positions_to_atom14_pos(
    aatype: torch.Tensor,  # (*, N)
    all_frames_to_global: geometry.Rigid3Array  # (N, 8)
) -> geometry.Vec3Array:  # (*, N, 14)
    """Put atom literature positions (atom14 encoding) in each rigid group."""
    # Pick the appropriate transform for every atom.
    residx_to_group_idx = get_rc_tensor(
        rc.restype_atom14_to_rigid_group, aatype
    )
    group_mask = torch.nn.functional.one_hot(
        residx_to_group_idx, num_classes=8
    )  # shape (*, N, 14, 8)

    # geometry.Rigid3Array with shape (N, 14)
    map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask
    map_atoms_to_global = map_atoms_to_global.map_tensor_fn(
        partial(torch.sum, dim=-1)
    )

    # Gather the literature atom positions for each residue.
    # geometry.Vec3Array with shape (N, 14)
    lit_positions = geometry.Vec3Array.from_array(
        get_rc_tensor(rc.restype_atom14_rigid_group_positions, aatype)
    )

    # Transform each atom from its local frame to the global frame.
    # geometry.Vec3Array with shape (N, 14)
    pred_positions = map_atoms_to_global.apply_to_point(lit_positions)

    # Mask out non-existing atoms.
    mask = get_rc_tensor(rc.restype_atom14_mask, aatype)
    pred_positions = pred_positions * mask

    return pred_positions
def extreme_ca_ca_distance_violations(
    positions: geometry.Vec3Array,  # (N, 37(14))
    mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbor."""
    this_ca_pos = positions[..., :-1, 1]  # (N - 1,)
    this_ca_mask = mask[..., :-1, 1]  # (N - 1)
    next_ca_pos = positions[..., 1:, 1]  # (N - 1,)
    next_ca_mask = mask[..., 1:, 1]  # (N - 1)
    has_no_gap_mask = (
        (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
    ).to(torch.float32)
    ca_ca_distance = geometry.euclidean_distance(
        this_ca_pos, next_ca_pos, eps
    )
    violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1)
def get_chi_atom_indices(device: torch.device):
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue
        types are in the order specified in rc.restypes + the unknown residue
        type at the end. For chi angles which are not defined on the residue,
        the position indices are by default set to 0.
    """
    chi_atom_indices = []
    for residue_name in rc.restypes:
        residue_name = rc.restype_1to3[residue_name]
        residue_chi_angles = rc.chi_angles_atoms[residue_name]
        atom_indices = []
        for chi_angle in residue_chi_angles:
            atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
        for _ in range(4 - len(atom_indices)):
            # For chi angles not defined on the AA.
            atom_indices.append([0, 0, 0, 0])
        chi_atom_indices.append(atom_indices)

    # For the UNKNOWN residue.
    chi_atom_indices.append([[0, 0, 0, 0]] * 4)

    return torch.tensor(chi_atom_indices, device=device)
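# --- Editorial aside (not part of the commit): the [21, 4, 4] table built
# above is consumed by plain integer indexing, which broadcasts over the
# residue dimension. A sketch with a stand-in table:
import torch

table = torch.arange(21 * 4 * 4).reshape(21, 4, 4)  # stand-in index table
aatype = torch.tensor([0, 7, 20])  # three residues (20 = UNKNOWN)
per_res = table[aatype]  # per-residue atom indices, shape [3, 4, 4]
assert per_res.shape == (3, 4, 4)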
def compute_chi_angles(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor
):
    """Computes the chi angles given all atom positions and the amino acid type.

    Args:
        positions: A Vec3Array of shape [num_res, rc.atom_type_num], with
            positions of atoms needed to calculate chi angles. Supports up
            to 1 batch dimension.
        mask: An optional tensor of shape [num_res, rc.atom_type_num] that
            masks which atom positions are set for each residue. If given,
            then the chi mask will be set to 1 for a chi angle only if the
            amino acid has that chi angle and all the chi atoms needed to
            calculate that chi angle are set. If not given (set to None),
            the chi mask will be set to 1 for a chi angle if the amino acid
            has that chi angle, and whether the actual atoms needed to
            calculate it were set will be ignored.
        aatype: A tensor of shape [num_res] with the amino acid type integer
            code (0 to 21). Supports up to 1 batch dimension.

    Returns:
        A tuple of tensors (chi_angles, mask), where both have shape
        [num_res, 4]. The mask masks out unused chi angles for amino acid
        types that have fewer than 4 chi angles. If atom_positions_mask is
        set, the chi mask will also mask out uncomputable chi angles.
    """
    # Don't assert on the num_res and batch dimensions as they might be
    # unknown.
    assert positions.shape[-1] == rc.atom_type_num
    assert mask.shape[-1] == rc.atom_type_num

    no_batch_dims = len(aatype.shape) - 1

    # Compute the table of chi angle indices.
    # Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices(aatype.device)

    # DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why
    # theirs works.
    aatype_gapless = torch.clamp(aatype, max=20)

    # Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4].
    atom_indices = chi_atom_indices[aatype_gapless]

    # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
    chi_angle_atoms = positions.map_tensor_fn(
        partial(
            tensor_utils.batched_gather,
            inds=atom_indices,
            dim=-1,
            no_batch_dims=no_batch_dims + 1,
        )
    )
    a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]

    chi_angles = geometry.dihedral_angle(a, b, c, d)

    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device)

    # Compute the chi angle mask. Shape: [num_res, chis=4].
    chi_mask = chi_angles_mask[aatype_gapless]

    # The chi_mask is set to 1 only when all necessary chi angle atoms were
    # set.
    # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = tensor_utils.batched_gather(
        mask, atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1
    )
    # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
    chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32)

    return chi_angles, chi_mask
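# --- Editorial aside (not part of the commit): a self-contained plain-torch
# version of the four-point dihedral that geometry.dihedral_angle is assumed
# to compute, using the standard atan2 formulation:
import torch

def dihedral_sketch(a, b, c, d):
    b1, b2, b3 = b - a, c - b, d - c
    n1 = torch.cross(b1, b2, dim=-1)  # normal of plane (a, b, c)
    n2 = torch.cross(b2, b3, dim=-1)  # normal of plane (b, c, d)
    m1 = torch.cross(n1, b2 / b2.norm(dim=-1, keepdim=True), dim=-1)
    return torch.atan2((m1 * n2).sum(-1), (n1 * n2).sum(-1))

# A planar cis arrangement of the four points gives a dihedral of ~0.
pts = [[0., 1., 0.], [0., 0., 0.], [1., 0., 0.], [1., 1., 0.]]
a, b, c, d = (torch.tensor(p) for p in pts)
assert dihedral_sketch(a, b, c, d).abs() < 1e-6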
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array,
    b_xyz: geometry.Vec3Array,
    c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide
    the coordinates in the non-standard way, the A atom will end up in the
    negative y-axis rather than in the positive y-axis. You need to take
    care of such cases in your code.

    Args:
        a_xyz: A Vec3Array.
        b_xyz: A Vec3Array.
        c_xyz: A Vec3Array.

    Returns:
        A Rigid3Array which, when applied to coordinates in a canonicalized
        reference frame, will give coordinates approximately equal to the
        original coordinates (in the global frame).
    """
    rotation = geometry.Rot3Array.from_two_vectors(
        c_xyz - b_xyz, a_xyz - b_xyz
    )
    return geometry.Rigid3Array(rotation, b_xyz)
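# --- Editorial aside (not part of the commit): a plain-torch sketch of what
# Rot3Array.from_two_vectors is assumed to do, i.e. Gram-Schmidt a frame
# whose x-axis follows the first vector and whose xy-plane contains the
# second:
import torch

def rot_from_two_vectors(e0: torch.Tensor, e1: torch.Tensor) -> torch.Tensor:
    e0 = e0 / e0.norm()
    e1 = e1 - (e0 @ e1) * e0  # remove the component along e0
    e1 = e1 / e1.norm()
    e2 = torch.cross(e0, e1, dim=-1)
    return torch.stack([e0, e1, e2], dim=-1)  # columns are the new axes

R = rot_from_two_vectors(torch.tensor([2., 0., 0.]), torch.tensor([1., 3., 0.]))
assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)  # orthonormal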
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    a = rc.atom_order['N']
    b = rc.atom_order['CA']
    c = rc.atom_order['C']

    rigid_mask = mask[..., a] * mask[..., b] * mask[..., c]

    rigid = make_transform_from_reference(
        a_xyz=positions[..., a],
        b_xyz=positions[..., b],
        c_xyz=positions[..., c],
    )

    return rigid, rigid_mask
openfold/utils/argparse_utils.py
0 → 100644
View file @
4bd1b4d5
from argparse import HelpFormatter
from operator import attrgetter


class ArgparseAlphabetizer(HelpFormatter):
    """
    Sorts the optional arguments of an argparse parser alphabetically
    """
    @staticmethod
    def sort_actions(actions):
        return sorted(actions, key=attrgetter("option_strings"))

    # Formats the help message
    def add_arguments(self, actions):
        actions = ArgparseAlphabetizer.sort_actions(actions)
        super(ArgparseAlphabetizer, self).add_arguments(actions)

    # Formats the usage message
    def add_usage(self, usage, actions, groups, prefix=None):
        actions = ArgparseAlphabetizer.sort_actions(actions)
        args = usage, actions, groups, prefix
        super(ArgparseAlphabetizer, self).add_usage(*args)


def remove_arguments(parser, args):
    for arg in args:
        for action in parser._actions:
            opts = vars(action)["option_strings"]
            if arg in opts:
                parser._handle_conflict_resolve(None, [(arg, action)])
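# --- Editorial aside (not part of the commit): a usage sketch. Passing the
# formatter class to ArgumentParser sorts options alphabetically in --help,
# and remove_arguments strips an unwanted inherited flag:
import argparse

parser = argparse.ArgumentParser(formatter_class=ArgparseAlphabetizer)
parser.add_argument("--zeta")
parser.add_argument("--alpha")
remove_arguments(parser, ["--zeta"])  # --zeta is no longer accepted
args = parser.parse_args(["--alpha", "1"])
print(args.alpha)  # "1"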