OpenDAS / OpenFold · Commits

Commit 56d5e39c, authored Jun 17, 2023 by Geoffrey Yu

    Merge remote-tracking branch 'upstream/multimer' into multimer

Parents: 56b86074, 51556d52

Changes (80): showing 20 changed files with 2983 additions and 284 deletions (+2983 −284)
.gitignore                                     +2    −0
openfold/__init__.py                           +1    −0
openfold/config.py                             +287  −1
openfold/data/data_pipeline.py                 +521  −153
openfold/data/data_transforms.py               +70   −27
openfold/data/data_transforms_multimer.py      +303  −0
openfold/data/feature_pipeline.py              +21   −7
openfold/data/feature_processing_multimer.py   +244  −0
openfold/data/input_pipeline_multimer.py       +136  −0
openfold/data/mmcif_parsing.py                 +17   −1
openfold/data/msa_identifiers.py               +91   −0
openfold/data/msa_pairing.py                   +483  −0
openfold/data/parsers.py                       +290  −17
openfold/data/templates.py                     +187  −62
openfold/data/tools/hhblits.py                 +4    −4
openfold/data/tools/hhsearch.py                +23   −4
openfold/data/tools/hmmbuild.py                +137  −0
openfold/data/tools/hmmsearch.py               +137  −0
openfold/data/tools/jackhmmer.py               +28   −7
openfold/data/tools/kalign.py                  +1    −1
.gitignore (+2 −0)

 .vscode/
+.idea/
 __pycache__/
 *.egg-info
 build
 ...
@@ -8,3 +9,4 @@ dist
 data
 openfold/resources/
+tests/test_data/
openfold/__init__.py (+1 −0)

 from . import model
 from . import utils
+from . import data
 from . import np
 from . import resources
 ...
openfold/config.py (+287 −1)

+import re
 import copy
 import importlib
 import ml_collections as mlc
 ...
@@ -16,7 +17,7 @@ def enforce_config_constraints(config):
         path = s.split('.')
         setting = config
         for p in path:
-            setting = setting[p]
+            setting = setting.get(p)
         return setting
 ...
@@ -152,6 +153,48 @@ def model_config(
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    elif "multimer" in name:
+        c.globals.is_multimer = True
+        c.globals.bfloat16 = True
+        c.globals.bfloat16_output = False
+        c.loss.masked_msa.num_classes = 22
+        c.data.common.max_recycling_iters = 20
+        for k, v in multimer_model_config_update.items():
+            c.model[k] = v
+        # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
+        if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
+            #c.model.input_embedder.num_msa = 252
+            #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
+            c.data.train.max_msa_clusters = 252
+            c.data.predict.max_msa_clusters = 252
+            c.data.train.max_extra_msa = 1152
+            c.data.predict.max_extra_msa = 1152
+            c.model.evoformer_stack.fuse_projection_weights = False
+            c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
+            c.model.template.template_pair_stack.fuse_projection_weights = False
+        elif name == 'model_4_multimer_v3':
+            #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
+            c.data.train.max_extra_msa = 1152
+            c.data.predict.max_extra_msa = 1152
+        elif name == 'model_5_multimer_v3':
+            #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
+            c.data.train.max_extra_msa = 1152
+            c.data.predict.max_extra_msa = 1152
+        else:
+            c.data.train.max_msa_clusters = 508
+            c.data.predict.max_msa_clusters = 508
+            c.data.train.max_extra_msa = 2048
+            c.data.predict.max_extra_msa = 2048
+        c.data.common.unsupervised_features.extend([
+            "msa_mask",
+            "seq_mask",
+            "asym_id",
+            "entity_id",
+            "sym_id",
+        ])
     else:
         raise ValueError("Invalid model name")
 ...
@@ -380,6 +423,7 @@ config = mlc.ConfigDict(
             "c_e": c_e,
             "c_s": c_s,
             "eps": eps,
+            "is_multimer": False,
         },
         "model": {
             "_mask_trans": False,
 ...
@@ -423,6 +467,8 @@ config = mlc.ConfigDict(
                 "no_heads": 4,
                 "pair_transition_n": 2,
                 "dropout_rate": 0.25,
+                "tri_mul_first": False,
+                "fuse_projection_weights": False,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "tune_chunk_size": tune_chunk_size,
                 "inf": 1e9,
 ...
@@ -471,6 +517,8 @@ config = mlc.ConfigDict(
                 "transition_n": 4,
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
+                "opm_first": False,
+                "fuse_projection_weights": False,
                 "clear_cache_between_blocks": False,
                 "tune_chunk_size": tune_chunk_size,
                 "inf": 1e9,
 ...
@@ -493,6 +541,8 @@ config = mlc.ConfigDict(
                 "transition_n": 4,
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
+                "opm_first": False,
+                "fuse_projection_weights": False,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "clear_cache_between_blocks": False,
                 "tune_chunk_size": tune_chunk_size,
 ...
@@ -585,6 +635,7 @@ config = mlc.ConfigDict(
                 "weight": 0.01,
             },
             "masked_msa": {
+                "num_classes": 23,
                 "eps": eps,  # 1e-8,
                 "weight": 2.0,
             },
 ...
@@ -597,6 +648,7 @@ config = mlc.ConfigDict(
             "violation": {
                 "violation_tolerance_factor": 12.0,
                 "clash_overlap_tolerance": 1.5,
+                "average_clashes": False,
                 "eps": eps,  # 1e-6,
                 "weight": 0.0,
             },
 ...
@@ -609,8 +661,242 @@ config = mlc.ConfigDict(
                 "weight": 0.,
                 "enabled": tm_enabled,
             },
+            "chain_center_of_mass": {
+                "clamp_distance": -4.0,
+                "weight": 0.,
+                "eps": eps,
+                "enabled": False,
+            },
             "eps": eps,
         },
         "ema": {"decay": 0.999},
+        # A negative value indicates that no early stopping will occur, i.e.
+        # the model will always run `max_recycling_iters` number of recycling
+        # iterations. A positive value will enable early stopping if the
+        # difference in pairwise distances is less than the tolerance between
+        # recycling steps.
+        "recycle_early_stop_tolerance": -1
     }
 )
+
+multimer_model_config_update = {
+    "input_embedder": {
+        "tf_dim": 21,
+        "msa_dim": 49,
+        #"num_msa": 508,
+        "c_z": c_z,
+        "c_m": c_m,
+        "relpos_k": 32,
+        "max_relative_chain": 2,
+        "max_relative_idx": 32,
+        "use_chain_relative": True,
+    },
+    "template": {
+        "distogram": {
+            "min_bin": 3.25,
+            "max_bin": 50.75,
+            "no_bins": 39,
+        },
+        "template_pair_embedder": {
+            "c_z": c_z,
+            "c_out": 64,
+            "c_dgram": 39,
+            "c_aatype": 22,
+        },
+        "template_single_embedder": {
+            "c_in": 34,
+            "c_m": c_m,
+        },
+        "template_pair_stack": {
+            "c_t": c_t,
+            # DISCREPANCY: c_hidden_tri_att here is given in the supplement
+            # as 64. In the code, it's 16.
+            "c_hidden_tri_att": 16,
+            "c_hidden_tri_mul": 64,
+            "no_blocks": 2,
+            "no_heads": 4,
+            "pair_transition_n": 2,
+            "dropout_rate": 0.25,
+            "tri_mul_first": True,
+            "fuse_projection_weights": True,
+            "blocks_per_ckpt": blocks_per_ckpt,
+            "inf": 1e9,
+        },
+        "c_t": c_t,
+        "c_z": c_z,
+        "inf": 1e5,  # 1e9,
+        "eps": eps,  # 1e-6,
+        "enabled": templates_enabled,
+        "embed_angles": embed_template_torsion_angles,
+        "use_unit_vector": True
+    },
+    "extra_msa": {
+        "extra_msa_embedder": {
+            "c_in": 25,
+            "c_out": c_e,
+            #"num_extra_msa": 2048
+        },
+        "extra_msa_stack": {
+            "c_m": c_e,
+            "c_z": c_z,
+            "c_hidden_msa_att": 8,
+            "c_hidden_opm": 32,
+            "c_hidden_mul": 128,
+            "c_hidden_pair_att": 32,
+            "no_heads_msa": 8,
+            "no_heads_pair": 4,
+            "no_blocks": 4,
+            "transition_n": 4,
+            "msa_dropout": 0.15,
+            "pair_dropout": 0.25,
+            "opm_first": True,
+            "fuse_projection_weights": True,
+            "clear_cache_between_blocks": True,
+            "inf": 1e9,
+            "eps": eps,  # 1e-10,
+            "ckpt": blocks_per_ckpt is not None,
+        },
+        "enabled": True,
+    },
+    "evoformer_stack": {
+        "c_m": c_m,
+        "c_z": c_z,
+        "c_hidden_msa_att": 32,
+        "c_hidden_opm": 32,
+        "c_hidden_mul": 128,
+        "c_hidden_pair_att": 32,
+        "c_s": c_s,
+        "no_heads_msa": 8,
+        "no_heads_pair": 4,
+        "no_blocks": 48,
+        "transition_n": 4,
+        "msa_dropout": 0.15,
+        "pair_dropout": 0.25,
+        "opm_first": True,
+        "fuse_projection_weights": True,
+        "blocks_per_ckpt": blocks_per_ckpt,
+        "clear_cache_between_blocks": False,
+        "inf": 1e9,
+        "eps": eps,  # 1e-10,
+    },
+    "structure_module": {
+        "c_s": c_s,
+        "c_z": c_z,
+        "c_ipa": 16,
+        "c_resnet": 128,
+        "no_heads_ipa": 12,
+        "no_qk_points": 4,
+        "no_v_points": 8,
+        "dropout_rate": 0.1,
+        "no_blocks": 8,
+        "no_transition_layers": 1,
+        "no_resnet_blocks": 2,
+        "no_angles": 7,
+        "trans_scale_factor": 20,
+        "epsilon": eps,  # 1e-12,
+        "inf": 1e5,
+    },
+    "heads": {
+        "lddt": {
+            "no_bins": 50,
+            "c_in": c_s,
+            "c_hidden": 128,
+        },
+        "distogram": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+        },
+        "tm": {
+            "c_z": c_z,
+            "no_bins": aux_distogram_bins,
+            "ptm_weight": 0.2,
+            "iptm_weight": 0.8,
+            "enabled": True,
+        },
+        "masked_msa": {
+            "c_m": c_m,
+            "c_out": 22,
+        },
+        "experimentally_resolved": {
+            "c_s": c_s,
+            "c_out": 37,
+        },
+    },
+    "loss": {
+        "distogram": {
+            "min_bin": 2.3125,
+            "max_bin": 21.6875,
+            "no_bins": 64,
+            "eps": eps,  # 1e-6,
+            "weight": 0.3,
+        },
+        "experimentally_resolved": {
+            "eps": eps,  # 1e-8,
+            "min_resolution": 0.1,
+            "max_resolution": 3.0,
+            "weight": 0.0,
+        },
+        "fape": {
+            "intra_chain_backbone": {
+                "clamp_distance": 10.0,
+                "loss_unit_distance": 10.0,
+                "weight": 0.5,
+            },
+            "interface": {
+                "clamp_distance": 30.0,
+                "loss_unit_distance": 20.0,
+                "weight": 0.5,
+            },
+            "sidechain": {
+                "clamp_distance": 10.0,
+                "length_scale": 10.0,
+                "weight": 0.5,
+            },
+            "eps": 1e-4,
+            "weight": 1.0,
+        },
+        "plddt_loss": {
+            "min_resolution": 0.1,
+            "max_resolution": 3.0,
+            "cutoff": 15.0,
+            "no_bins": 50,
+            "eps": eps,  # 1e-10,
+            "weight": 0.01,
+        },
+        "masked_msa": {
+            "num_classes": 23,
+            "eps": eps,  # 1e-8,
+            "weight": 2.0,
+        },
+        "supervised_chi": {
+            "chi_weight": 0.5,
+            "angle_norm_weight": 0.01,
+            "eps": eps,  # 1e-6,
+            "weight": 1.0,
+        },
+        "violation": {
+            "violation_tolerance_factor": 12.0,
+            "clash_overlap_tolerance": 1.5,
+            "average_clashes": True,
+            "eps": eps,  # 1e-6,
+            "weight": 0.03,  # Not finetuning
+        },
+        "tm": {
+            "max_bin": 31,
+            "no_bins": 64,
+            "min_resolution": 0.1,
+            "max_resolution": 3.0,
+            "eps": eps,  # 1e-8,
+            "weight": 0.1,
+            "enabled": True,
+        },
+        "chain_center_of_mass": {
+            "clamp_distance": -4.0,
+            "weight": 0.05,
+            "eps": eps,
+            "enabled": True,
+        },
+        "eps": eps,
+    },
+    "recycle_early_stop_tolerance": 0.5
+}
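
The multimer presets above are selected purely by name. A minimal usage sketch (not part of the diff; it assumes the model_config entry point defined earlier in this file):

from openfold.config import model_config

c = model_config("model_1_multimer_v3")
# The "multimer" branch above has fired:
#   c.globals.is_multimer         -> True
#   c.loss.masked_msa.num_classes -> 22
# "model_1_multimer_v3" does not match ^model_[1-5]_multimer(_v2)?$ and is
# not model_4/5_multimer_v3, so the else branch applies:
#   c.data.train.max_msa_clusters -> 508, c.data.train.max_extra_msa -> 2048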
openfold/data/data_pipeline.py (+521 −153)

(This diff is collapsed.)
openfold/data/data_transforms.py (+70 −27)

@@ -23,6 +23,9 @@ import torch
 from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
 from openfold.np import residue_constants as rc
 from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from openfold.utils.geometry.rotation_matrix import Rot3Array
+from openfold.utils.geometry.vector import Vec3Array
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
@@ -93,7 +96,7 @@ def fix_templates_aatype(protein):
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        new_order_list, dtype=torch.int64, device=protein["aatype"].device,
+        new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
     ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
@@ -439,13 +442,15 @@ def make_hhblits_profile(protein):
 @curry1
-def make_masked_msa(protein, config, replace_fraction):
+def make_masked_msa(protein, config, replace_fraction, seed):
     """Create data for BERT on raw MSA."""
+    device = protein["msa"].device
     # Add a random amino acid uniformly.
     random_aa = torch.tensor(
         [0.05] * 20 + [0.0, 0.0],
         dtype=torch.float32,
-        device=protein["aatype"].device
+        device=device
     )
     categorical_probs = (
@@ -465,11 +470,17 @@ def make_masked_msa(protein, config, replace_fraction):
     assert mask_prob >= 0.0
     categorical_probs = torch.nn.functional.pad(
-        categorical_probs, pad_shapes, value=mask_prob
+        categorical_probs, pad_shapes, value=mask_prob,
     )
     sh = protein["msa"].shape
-    mask_position = torch.rand(sh) < replace_fraction
+    g = torch.Generator(device=protein["msa"].device)
+    if seed is not None:
+        g.manual_seed(seed)
+    sample = torch.rand(sh, device=device, generator=g)
+    mask_position = sample < replace_fraction
     bert_msa = shaped_categorical(categorical_probs)
     bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
@@ -662,7 +673,7 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
     batch = tree_map(
         lambda n: torch.tensor(n, device="cpu"),
         batch,
         np.ndarray
     )
     out = make_atom14_masks(batch)
@@ -728,7 +739,7 @@ def make_atom14_positions(protein):
         for index, correspondence in enumerate(correspondences):
             renaming_matrix[index, correspondence] = 1.0
         all_matrices[resname] = renaming_matrix
     renaming_matrices = torch.stack(
         [all_matrices[restype] for restype in restype_3]
     )
@@ -774,10 +785,14 @@ def make_atom14_positions(protein):
 def atom37_to_frames(protein, eps=1e-8):
+    is_multimer = "asym_id" in protein
     aatype = protein["aatype"]
     all_atom_positions = protein["all_atom_positions"]
     all_atom_mask = protein["all_atom_mask"]
+    if is_multimer:
+        all_atom_positions = Vec3Array.from_array(all_atom_positions)
     batch_dims = len(aatype.shape[:-1])
     restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
@@ -824,19 +839,37 @@ def atom37_to_frames(protein, eps=1e-8):
         no_batch_dims=batch_dims,
     )
-    base_atom_pos = batched_gather(
-        all_atom_positions,
-        residx_rigidgroup_base_atom37_idx,
-        dim=-2,
-        no_batch_dims=len(all_atom_positions.shape[:-2]),
-    )
+    if is_multimer:
+        base_atom_pos = [batched_gather(
+            pos,
+            residx_rigidgroup_base_atom37_idx,
+            dim=-1,
+            no_batch_dims=len(all_atom_positions.shape[:-1]),
+        ) for pos in all_atom_positions]
+        base_atom_pos = Vec3Array.from_array(torch.stack(base_atom_pos, dim=-1))
+    else:
+        base_atom_pos = batched_gather(
+            all_atom_positions,
+            residx_rigidgroup_base_atom37_idx,
+            dim=-2,
+            no_batch_dims=len(all_atom_positions.shape[:-2]),
+        )
-    gt_frames = Rigid.from_3_points(
-        p_neg_x_axis=base_atom_pos[..., 0, :],
-        origin=base_atom_pos[..., 1, :],
-        p_xy_plane=base_atom_pos[..., 2, :],
-        eps=eps,
-    )
+    if is_multimer:
+        point_on_neg_x_axis = base_atom_pos[:, :, 0]
+        origin = base_atom_pos[:, :, 1]
+        point_on_xy_plane = base_atom_pos[:, :, 2]
+        gt_rotation = Rot3Array.from_two_vectors(
+            origin - point_on_neg_x_axis,
+            point_on_xy_plane - origin)
+        gt_frames = Rigid3Array(gt_rotation, origin)
+    else:
+        gt_frames = Rigid.from_3_points(
+            p_neg_x_axis=base_atom_pos[..., 0, :],
+            origin=base_atom_pos[..., 1, :],
+            p_xy_plane=base_atom_pos[..., 2, :],
+            eps=eps,
+        )
     group_exists = batched_gather(
         restype_rigidgroup_mask,
@@ -857,9 +890,13 @@ def atom37_to_frames(protein, eps=1e-8):
     rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
     rots[..., 0, 0, 0] = -1
     rots[..., 0, 2, 2] = -1
-    rots = Rotation(rot_mats=rots)
-    gt_frames = gt_frames.compose(Rigid(rots, None))
+    if is_multimer:
+        gt_frames = gt_frames.compose_rotation(Rot3Array.from_array(rots))
+    else:
+        rots = Rotation(rot_mats=rots)
+        gt_frames = gt_frames.compose(Rigid(rots, None))
     restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
         *((1,) * batch_dims), 21, 8
@@ -893,12 +930,18 @@ def atom37_to_frames(protein, eps=1e-8):
         no_batch_dims=batch_dims,
     )
-    residx_rigidgroup_ambiguity_rot = Rotation(
-        rot_mats=residx_rigidgroup_ambiguity_rot
-    )
-    alt_gt_frames = gt_frames.compose(
-        Rigid(residx_rigidgroup_ambiguity_rot, None)
-    )
+    if is_multimer:
+        ambiguity_rot = Rot3Array.from_array(residx_rigidgroup_ambiguity_rot)
+        # Create the alternative ground truth frames.
+        alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
+    else:
+        residx_rigidgroup_ambiguity_rot = Rotation(
+            rot_mats=residx_rigidgroup_ambiguity_rot
+        )
+        alt_gt_frames = gt_frames.compose(
+            Rigid(residx_rigidgroup_ambiguity_rot, None)
+        )
     gt_frames_tensor = gt_frames.to_tensor_4x4()
     alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
 ...
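
For the multimer path, the rotation is built with Rot3Array.from_two_vectors rather than Rigid.from_3_points. A standalone sketch of the Gram-Schmidt construction this performs (rotation_from_two_vectors is a hypothetical helper for illustration, not OpenFold API):

import torch

def rotation_from_two_vectors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # e0: unit vector along a (the frame's x-axis).
    e0 = a / torch.linalg.norm(a)
    # e1: component of b orthogonal to e0, normalized (the y-axis).
    e1 = b - torch.dot(e0, b) * e0
    e1 = e1 / torch.linalg.norm(e1)
    # e2 completes a right-handed orthonormal basis (the z-axis).
    e2 = torch.linalg.cross(e0, e1)
    return torch.stack([e0, e1, e2], dim=-1)  # columns are the frame axes

# Three backbone-like points: neg-x atom, origin atom, xy-plane atom.
p_neg_x = torch.tensor([0.0, 0.0, 0.0])
origin = torch.tensor([1.5, 0.0, 0.0])
p_xy = torch.tensor([2.0, 1.0, 0.0])
R = rotation_from_two_vectors(origin - p_neg_x, p_xy - origin)
# R @ R.T is (numerically) the identity.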
openfold/data/data_transforms_multimer.py (new file, 0 → 100644, +303 −0)

from typing import Sequence

import torch

from openfold.data.data_transforms import curry1
from openfold.utils.tensor_utils import masked_mean


def gumbel_noise(
    shape: Sequence[int],
    device: torch.device,
    eps=1e-6,
    generator=None,
) -> torch.Tensor:
    """Generate Gumbel Noise of given Shape.

    This generates samples from Gumbel(0, 1).

    Args:
        shape: Shape of noise to return.

    Returns:
        Gumbel noise of given shape.
    """
    uniform_noise = torch.rand(
        shape, dtype=torch.float32, device=device, generator=generator
    )
    gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
    return gumbel


def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
    """Samples from a probability distribution given by 'logits'.

    This uses Gumbel-max trick to implement the sampling in an efficient manner.

    Args:
        logits: Logarithm of probabilities to sample from, probabilities can be
            unnormalized.

    Returns:
        Sample from logprobs in one-hot form.
    """
    z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
    return torch.nn.functional.one_hot(
        torch.argmax(logits + z, dim=-1),
        logits.shape[-1],
    )


def gumbel_argsort_sample_idx(logits: torch.Tensor, generator=None) -> torch.Tensor:
    """Samples with replacement from a distribution given by 'logits'.

    This uses Gumbel trick to implement the sampling an efficient manner. For a
    distribution over k items this samples k times without replacement, so this
    is effectively sampling a random permutation with probabilities over the
    permutations derived from the logprobs.

    Args:
        logits: Logarithm of probabilities to sample from, probabilities can be
            unnormalized.

    Returns:
        Sample from logprobs in one-hot form.
    """
    z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
    return torch.argsort(logits + z, dim=-1, descending=True)


@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
    """Create data for BERT on raw MSA."""
    # Add a random amino acid uniformly.
    random_aa = torch.Tensor(
        [0.05] * 20 + [0., 0.], device=batch['msa'].device
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * batch['msa_profile']
        + config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
    )

    # Put all remaining probability on [MASK] which is a new column.
    mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob

    categorical_probs = torch.nn.functional.pad(
        categorical_probs, [0, 1], value=mask_prob
    )

    sh = batch['msa'].shape
    mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
    mask_position *= batch['msa_mask'].to(mask_position.dtype)

    logits = torch.log(categorical_probs + eps)

    g = torch.Generator(device=batch["msa"].device)
    if seed is not None:
        g.manual_seed(seed)

    bert_msa = gumbel_max_sample(logits, generator=g)

    bert_msa = torch.where(
        mask_position,
        torch.argmax(bert_msa, dim=-1),
        batch['msa']
    )
    bert_msa *= batch['msa_mask'].to(bert_msa.dtype)

    # Mix real and masked MSA.
    if 'bert_mask' in batch:
        batch['bert_mask'] *= mask_position.to(torch.float32)
    else:
        batch['bert_mask'] = mask_position.to(torch.float32)
    batch['true_msa'] = batch['msa']
    batch['msa'] = bert_msa

    return batch


@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
    device = batch["msa_mask"].device

    # Determine how much weight we assign to each agreement. In theory, we could
    # use a full blosum matrix here, but right now let's just down-weight gap
    # agreement because it could be spurious.
    # Never put weight on agreeing on BERT mask.
    weights = torch.Tensor(
        [1.] * 21 + [gap_agreement_weight] + [0.], device=device,
    )

    msa_mask = batch['msa_mask']
    msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)

    extra_mask = batch['extra_msa_mask']
    extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)

    msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
    extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

    agreement = torch.einsum(
        'mrc, nrc->nm', extra_one_hot_masked, weights * msa_one_hot_masked
    )

    cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
    cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)

    cluster_count = torch.sum(cluster_assignment, dim=-1)
    cluster_count += 1.  # We always include the sequence itself.

    msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
    msa_sum += msa_one_hot_masked

    cluster_profile = msa_sum / cluster_count[:, None, None]

    extra_deletion_matrix = batch['extra_deletion_matrix']
    deletion_matrix = batch['deletion_matrix']

    del_sum = torch.einsum(
        'nm, mc->nc', cluster_assignment, extra_mask * extra_deletion_matrix
    )
    del_sum += deletion_matrix  # Original sequence.
    cluster_deletion_mean = del_sum / cluster_count[:, None]

    batch['cluster_profile'] = cluster_profile
    batch['cluster_deletion_mean'] = cluster_deletion_mean

    return batch


def create_target_feat(batch):
    """Create the target features"""
    batch["target_feat"] = torch.nn.functional.one_hot(
        batch["aatype"], 21
    ).to(torch.float32)
    return batch


def create_msa_feat(batch):
    """Create and concatenate MSA features."""
    device = batch["msa"]
    msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
    deletion_matrix = batch['deletion_matrix']
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]

    deletion_mean_value = (
        torch.atan(batch['cluster_deletion_mean'] / 3.) * (2. / pi)
    )[..., None]

    msa_feat = torch.cat(
        [
            msa_1hot,
            has_deletion,
            deletion_value,
            batch['cluster_profile'],
            deletion_mean_value
        ],
        dim=-1,
    )

    batch["msa_feat"] = msa_feat

    return batch


def build_extra_msa_feat(batch):
    """Expand extra_msa into 1hot and concat with other extra msa features.

    We do this as late as possible as the one_hot extra msa can be very large.

    Args:
        batch: a dictionary with the following keys:
            * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
                centre. Note - This isn't one-hotted.
            * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
                position.
        num_extra_msa: Number of extra msa to use.

    Returns:
        Concatenated tensor of extra MSA features.
    """
    # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
    extra_msa = batch['extra_msa']
    deletion_matrix = batch['extra_deletion_matrix']
    msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    deletion_value = (
        (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
    )
    extra_msa_mask = batch['extra_msa_mask']
    catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)

    return catted


@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
    """Sample MSA randomly, remaining sequences are stored as `extra_*`.

    Args:
        batch: batch to sample msa from.
        max_seq: number of sequences to sample.
    Returns:
        Protein with sampled msa.
    """
    g = torch.Generator(device=batch["msa"].device)
    if seed is not None:
        g.manual_seed(seed)

    # Sample uniformly among sequences with at least one non-masked position.
    logits = (
        torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.
    ) * inf

    # The cluster_bias_mask can be used to preserve the first row (target
    # sequence) for each chain, for example.
    if 'cluster_bias_mask' not in batch:
        cluster_bias_mask = torch.nn.functional.pad(
            batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
            (1, 0),
            value=1.
        )
    else:
        cluster_bias_mask = batch['cluster_bias_mask']

    logits += cluster_bias_mask * inf
    index_order = gumbel_argsort_sample_idx(logits, generator=g)
    sel_idx = index_order[:max_seq]
    extra_idx = index_order[max_seq:][:max_extra_msa_seq]

    for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
        if k in batch:
            batch['extra_' + k] = batch[k][extra_idx]
            batch[k] = batch[k][sel_idx]

    return batch


def make_msa_profile(batch):
    """Compute the MSA profile."""
    # Compute the profile for every residue (over all MSA sequences).
    batch["msa_profile"] = masked_mean(
        batch['msa_mask'][..., None],
        torch.nn.functional.one_hot(batch['msa'], 22),
        dim=-3,
    )
    return batch
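
A quick self-contained check (illustrative only, not part of the commit) that the Gumbel-max trick used by gumbel_max_sample draws from the intended categorical distribution:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
logits = torch.log(probs)
# Same noise construction as gumbel_noise above.
u = torch.rand(100_000, 3)
z = -torch.log(-torch.log(u + 1e-6) + 1e-6)
samples = torch.argmax(logits + z, dim=-1)
freqs = torch.bincount(samples, minlength=3) / samples.numel()
# freqs is approximately tensor([0.1, 0.2, 0.7])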
openfold/data/feature_pipeline.py (+21 −7)

@@ -20,7 +20,7 @@ import ml_collections
 import numpy as np
 import torch
-from openfold.data import input_pipeline
+from openfold.data import input_pipeline, input_pipeline_multimer
 FeatureDict = Mapping[str, np.ndarray]
 ...
@@ -74,8 +74,10 @@ def np_example_to_features(
     np_example: FeatureDict,
     config: ml_collections.ConfigDict,
     mode: str,
+    is_multimer: bool = False
 ):
     np_example = dict(np_example)
     num_res = int(np_example["seq_length"][0])
     cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
 ...
@@ -88,11 +90,18 @@ def np_example_to_features(
         np_example=np_example, features=feature_names
     )
     with torch.no_grad():
-        features = input_pipeline.process_tensors_from_config(
-            tensor_dict,
-            cfg.common,
-            cfg[mode],
-        )
+        if (not is_multimer):
+            features = input_pipeline.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
+        else:
+            features = input_pipeline_multimer.process_tensors_from_config(
+                tensor_dict,
+                cfg.common,
+                cfg[mode],
+            )
     if mode == "train":
         p = torch.rand(1).item()
 ...
@@ -122,10 +131,15 @@ class FeaturePipeline:
     def process_features(
         self,
         raw_features: FeatureDict,
         mode: str = "train",
+        is_multimer: bool = False,
     ) -> FeatureDict:
+        if (is_multimer and mode != "predict"):
+            raise ValueError("Multimer mode is not currently trainable")
+
         return np_example_to_features(
             np_example=raw_features,
             config=self.config,
             mode=mode,
+            is_multimer=is_multimer,
         )
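
A hedged usage sketch for the new is_multimer flag (raw_features is assumed to be a FeatureDict produced by the data pipeline; constructing FeaturePipeline from config.data follows the usage in the rest of the repository):

from openfold.config import model_config
from openfold.data.feature_pipeline import FeaturePipeline

config = model_config("model_1_multimer_v3")
pipeline = FeaturePipeline(config.data)

# Multimer features are only supported at inference time for now:
feats = pipeline.process_features(raw_features, mode="predict", is_multimer=True)

# mode="train" with is_multimer=True raises
# ValueError("Multimer mode is not currently trainable").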
openfold/data/feature_processing_multimer.py (new file, 0 → 100644, +244 −0)

# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""

from typing import Iterable, MutableMapping, List, Mapping

from openfold.data import msa_pairing
from openfold.np import residue_constants

import numpy as np

# TODO: Move this into the config
REQUIRED_FEATURES = frozenset({
    'aatype', 'all_atom_mask', 'all_atom_positions',
    'all_chains_entity_ids', 'all_crops_all_chains_mask',
    'all_crops_all_chains_positions', 'all_crops_all_chains_residue_ids',
    'assembly_num_chains', 'asym_id', 'bert_mask', 'cluster_bias_mask',
    'deletion_matrix', 'deletion_mean', 'entity_id', 'entity_mask',
    'mem_peak', 'msa', 'msa_mask', 'num_alignments',
    'num_templates', 'queue_size', 'residue_index', 'resolution',
    'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
    'template_all_atom_mask', 'template_all_atom_positions'
})

MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example."""
    # Note that an entity_id of 0 indicates padding.
    num_unique_chains = len(np.unique(np.concatenate(
        [np.unique(chain['entity_id'][chain['entity_id'] > 0])
         for chain in chains])))
    return num_unique_chains == 1


def pair_and_merge(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
) -> Mapping[str, np.ndarray]:
    """Runs processing on features to augment, pair and merge.

    Args:
        all_chain_features: A MutableMap of dictionaries of features for each chain.

    Returns:
        A dictionary of features.
    """
    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())

    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

    if pair_msa_sequences:
        np_chains_list = msa_pairing.create_paired_features(
            chains=np_chains_list
        )
        np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)

    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )
    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )
    np_example = process_final(np_example)

    return np_example


def crop_chains(
    chains_list: List[Mapping[str, np.ndarray]],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int
) -> List[Mapping[str, np.ndarray]]:
    """Crops the MSAs for a set of chains.

    Args:
        chains_list: A list of chains to be cropped.
        msa_crop_size: The total number of sequences to crop from the MSA.
        pair_msa_sequences: Whether we are operating in sequence-pairing mode.
        max_templates: The maximum templates to use per chain.

    Returns:
        The chains cropped.
    """
    # Apply the cropping.
    cropped_chains = []
    for chain in chains_list:
        cropped_chain = _crop_single_chain(
            chain,
            msa_crop_size=msa_crop_size,
            pair_msa_sequences=pair_msa_sequences,
            max_templates=max_templates
        )
        cropped_chains.append(cropped_chain)

    return cropped_chains


def _crop_single_chain(
    chain: Mapping[str, np.ndarray],
    msa_crop_size: int,
    pair_msa_sequences: bool,
    max_templates: int
) -> Mapping[str, np.ndarray]:
    """Crops msa sequences to `msa_crop_size`."""
    msa_size = chain['num_alignments']

    if pair_msa_sequences:
        msa_size_all_seq = chain['num_alignments_all_seq']
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This keeps
        # the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)
        )
        num_non_gapped_pairs = np.minimum(
            num_non_gapped_pairs, msa_crop_size_all_seq
        )

        # Restrict the unpaired crop size so that paired+unpaired sequences do not
        # exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32
        )
    return chain


def process_final(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Final processing steps in data pipeline, after merging and pairing."""
    np_example = _correct_msa_restypes(np_example)
    np_example = _make_seq_mask(np_example)
    np_example = _make_msa_mask(np_example)
    np_example = _filter_features(np_example)
    return np_example


def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants."""
    new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0)
    np_example['msa'] = np_example['msa'].astype(np.int32)
    return np_example


def _make_seq_mask(np_example):
    np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
    return np_example


def _make_msa_mask(np_example):
    """Mask features are all ones, but will later be zero-padded."""
    np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32)

    seq_mask = (np_example['entity_id'] > 0).astype(np.float32)
    np_example['msa_mask'] *= seq_mask[None]

    return np_example


def _filter_features(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested."""
    return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
    """Postprocessing stage for per-chain features before merging."""
    num_chains = len(all_chain_features)
    for chain_features in all_chain_features.values():
        # Convert deletion matrices to float.
        chain_features['deletion_matrix'] = np.asarray(
            chain_features.pop('deletion_matrix_int'), dtype=np.float32
        )
        if 'deletion_matrix_int_all_seq' in chain_features:
            chain_features['deletion_matrix_all_seq'] = np.asarray(
                chain_features.pop('deletion_matrix_int_all_seq'),
                dtype=np.float32
            )

        chain_features['deletion_mean'] = np.mean(
            chain_features['deletion_matrix'], axis=0
        )

        # Add all_atom_mask and dummy all_atom_positions based on aatype.
        all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
            chain_features['aatype']
        ]
        chain_features['all_atom_mask'] = all_atom_mask
        chain_features['all_atom_positions'] = np.zeros(
            list(all_atom_mask.shape) + [3]
        )

        # Add assembly_num_chains.
        chain_features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for chain_features in all_chain_features.values():
        chain_features['entity_mask'] = (
            chain_features['entity_id'] != 0
        ).astype(np.int32)
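
A worked example (hypothetical sizes) of the paired/unpaired budget split performed by _crop_single_chain with MSA_CROP_SIZE = 2048:

import numpy as np

msa_crop_size = 2048          # MSA_CROP_SIZE
msa_size_all_seq = 1500       # paired rows available for this chain (assumed)
msa_size = 5000               # unpaired rows available (assumed)

# Paired rows get at most half the budget.
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)  # 1024

# Suppose 900 of the kept paired rows are not pure gap (assumed figure).
num_non_gapped_pairs = np.minimum(900, msa_crop_size_all_seq)             # 900

# Unpaired rows fill whatever budget the non-gap paired rows left over.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)   # 1148
unpaired_crop = np.minimum(msa_size, max_msa_crop_size)                   # 1148
# 900 non-gap paired + 1148 unpaired == 2048 == msa_crop_size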
openfold/data/input_pipeline_multimer.py (new file, 0 → 100644, +136 −0)

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch

from openfold.data import (
    data_transforms,
    data_transforms_multimer,
)


def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
        data_transforms.make_atom14_masks,
    ]

    if (common_cfg.use_templates):
        transforms.extend([
            data_transforms.make_pseudo_beta("template_"),
        ])

    return transforms


def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    msa_seed = None
    if (not common_cfg.resample_msa_in_recycling):
        msa_seed = ensemble_seed

    transforms.append(
        data_transforms_multimer.sample_msa(
            max_msa_clusters,
            max_extra_msa,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms_multimer.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                seed=(msa_seed + 1) if msa_seed else None,
            )
        )

    transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
    transforms.append(data_transforms_multimer.create_msa_feat)

    return transforms


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    no_templates = True
    if ("template_aatype" in tensors):
        no_templates = tensors["template_aatype"].shape[0] == 0

    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    if ("no_recycling_iters" in tensors):
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors


@data_transforms.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
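
A toy sketch (illustrative data, not OpenFold tensors) of how compose chains the transforms and map_fn stacks one ensembled copy per recycling iteration onto a trailing dimension:

import torch

def compose(fs):
    def run(x):
        for f in fs:
            x = f(x)
        return x
    return run

add_one = lambda d: {**d, "x": d["x"] + 1}
batch = {"x": torch.zeros(2)}

# One ensembled copy per "recycling iteration", as in wrap_ensemble_fn.
ensembles = [compose([add_one])(dict(batch)) for _ in range(3)]

# Stack per-feature along a new last dim, as in map_fn.
stacked = {k: torch.stack([e[k] for e in ensembles], dim=-1) for k in ensembles[0]}
# stacked["x"].shape == torch.Size([2, 3])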
openfold/data/mmcif_parsing.py (+17 −1)

@@ -16,6 +16,7 @@
 """Parses the mmCIF file format."""
 import collections
 import dataclasses
+import functools
 import io
 import json
 import logging
 ...
@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
     return {entry[index]: entry for entry in entries}
+@functools.lru_cache(16, typed=False)
 def parse(
     *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
 ) -> ParsingResult:
 ...
@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
             raw_resolution = parsed_info[res_key][0]
             header["resolution"] = float(raw_resolution)
         except ValueError:
-            logging.info(
+            logging.debug(
                 "Invalid resolution format: %s", parsed_info[res_key]
             )
 ...
@@ -474,6 +476,20 @@ def get_atom_coords(
                     pos[residue_constants.atom_order["SD"]] = [x, y, z]
                     mask[residue_constants.atom_order["SD"]] = 1.0
+            # Fix naming errors in arginine residues where NH2 is incorrectly
+            # assigned to be closer to CD than NH1
+            cd = residue_constants.atom_order['CD']
+            nh1 = residue_constants.atom_order['NH1']
+            nh2 = residue_constants.atom_order['NH2']
+            if (
+                res.get_resname() == 'ARG'
+                and all(mask[atom_index] for atom_index in (cd, nh1, nh2))
+                and (
+                    np.linalg.norm(pos[nh1] - pos[cd])
+                    > np.linalg.norm(pos[nh2] - pos[cd])
+                )
+            ):
+                pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
+                mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
             all_atom_positions[res_index] = pos
             all_atom_mask[res_index] = mask
 ...
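
A standalone numeric sketch (hypothetical coordinates) of the arginine NH1/NH2 correction added above: when NH1 sits farther from CD than NH2, the two atoms were mislabeled, so their positions and masks are swapped.

import numpy as np

pos_cd = np.array([0.0, 0.0, 0.0])
pos_nh1 = np.array([3.0, 0.0, 0.0])  # mislabeled: farther from CD
pos_nh2 = np.array([2.0, 0.0, 0.0])  # mislabeled: closer to CD

if np.linalg.norm(pos_nh1 - pos_cd) > np.linalg.norm(pos_nh2 - pos_cd):
    pos_nh1, pos_nh2 = pos_nh2.copy(), pos_nh1.copy()

# pos_nh1 is now the atom closer to CD, matching the naming convention.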
openfold/data/msa_identifiers.py (new file, 0 → 100644, +91 −0)

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""

import dataclasses
import re
from typing import Optional


# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
    r"""
    ^
    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
    (?:tr|sp)
    \|
    # A primary accession number of the UniProtKB entry.
    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
    (?:_\d)?
    \|
    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
    # protein ID code.
    (?:[A-Za-z0-9]+)
    _
    # A mnemonic species identification code.
    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
    # Small BFD uses a final value after an underscore, which we ignore.
    (?:_\d+)?
    $
    """,
    re.VERBOSE)


@dataclasses.dataclass(frozen=True)
class Identifiers:
    species_id: str = ''


def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Gets accession id and species from an msa sequence identifier.

    The sequence identifier has the format specified by
    _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
    An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`

    Args:
        msa_sequence_identifier: a sequence identifier.

    Returns:
        An `Identifiers` instance with a uniprot_accession_id and species_id. These
        can be empty in the case where no identifier was found.
    """
    matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
    if matches:
        return Identifiers(
            species_id=matches.group('SpeciesIdentifier')
        )
    return Identifiers()


def _extract_sequence_identifier(description: str) -> Optional[str]:
    """Extracts sequence identifier from description. Returns None if no match."""
    split_description = description.split()
    if split_description:
        return split_description[0].partition('/')[0]
    else:
        return None


def get_identifiers(description: str) -> Identifiers:
    """Computes extra MSA features from the description."""
    sequence_identifier = _extract_sequence_identifier(description)
    if sequence_identifier is None:
        return Identifiers()
    else:
        return _parse_sequence_identifier(sequence_identifier)
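
Usage sketch for the parser above, using the example identifier from the docstring:

from openfold.data import msa_identifiers

ids = msa_identifiers.get_identifiers(
    "tr|A0A146SKV9|A0A146SKV9_FUNHE some free-text description"
)
# ids == Identifiers(species_id='FUNHE')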
openfold/data/msa_pairing.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import
collections
import
functools
import
string
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Sequence
,
Mapping
import
numpy
as
np
import
pandas
as
pd
import
scipy.linalg
from
openfold.np
import
residue_constants
# TODO: This stuff should probably also be in a config
MSA_GAP_IDX
=
residue_constants
.
restypes_with_x_and_gap
.
index
(
'-'
)
SEQUENCE_GAP_CUTOFF
=
0.5
SEQUENCE_SIMILARITY_CUTOFF
=
0.9
MSA_PAD_VALUES
=
{
'msa_all_seq'
:
MSA_GAP_IDX
,
'msa_mask_all_seq'
:
1
,
'deletion_matrix_all_seq'
:
0
,
'deletion_matrix_int_all_seq'
:
0
,
'msa'
:
MSA_GAP_IDX
,
'msa_mask'
:
1
,
'deletion_matrix'
:
0
,
'deletion_matrix_int'
:
0
}
MSA_FEATURES
=
(
'msa'
,
'msa_mask'
,
'deletion_matrix'
,
'deletion_matrix_int'
)
SEQ_FEATURES
=
(
'residue_index'
,
'aatype'
,
'all_atom_positions'
,
'all_atom_mask'
,
'seq_mask'
,
'between_segment_residues'
,
'has_alt_locations'
,
'has_hetatoms'
,
'asym_id'
,
'entity_id'
,
'sym_id'
,
'entity_mask'
,
'deletion_mean'
,
'prediction_atom_mask'
,
'literature_positions'
,
'atom_indices_to_group_indices'
,
'rigid_group_default_frame'
)
TEMPLATE_FEATURES
=
(
'template_aatype'
,
'template_all_atom_positions'
,
'template_all_atom_mask'
)
CHAIN_FEATURES
=
(
'num_alignments'
,
'seq_length'
)
def
create_paired_features
(
chains
:
Iterable
[
Mapping
[
str
,
np
.
ndarray
]],
)
->
List
[
Mapping
[
str
,
np
.
ndarray
]]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains
=
list
(
chains
)
chain_keys
=
chains
[
0
].
keys
()
if
len
(
chains
)
<
2
:
return
chains
else
:
updated_chains
=
[]
paired_chains_to_paired_row_indices
=
pair_sequences
(
chains
)
paired_rows
=
reorder_paired_rows
(
paired_chains_to_paired_row_indices
)
for
chain_num
,
chain
in
enumerate
(
chains
):
new_chain
=
{
k
:
v
for
k
,
v
in
chain
.
items
()
if
'_all_seq'
not
in
k
}
for
feature_name
in
chain_keys
:
if
feature_name
.
endswith
(
'_all_seq'
):
feats_padded
=
pad_features
(
chain
[
feature_name
],
feature_name
)
new_chain
[
feature_name
]
=
feats_padded
[
paired_rows
[:,
chain_num
]]
new_chain
[
'num_alignments_all_seq'
]
=
np
.
asarray
(
len
(
paired_rows
[:,
chain_num
]))
updated_chains
.
append
(
new_chain
)
return
updated_chains
def
pad_features
(
feature
:
np
.
ndarray
,
feature_name
:
str
)
->
np
.
ndarray
:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert
feature
.
dtype
!=
np
.
dtype
(
np
.
string_
)
if
feature_name
in
(
'msa_all_seq'
,
'msa_mask_all_seq'
,
'deletion_matrix_all_seq'
,
'deletion_matrix_int_all_seq'
):
num_res
=
feature
.
shape
[
1
]
padding
=
MSA_PAD_VALUES
[
feature_name
]
*
np
.
ones
([
1
,
num_res
],
feature
.
dtype
)
elif
feature_name
==
'msa_species_identifiers_all_seq'
:
padding
=
[
b
''
]
else
:
return
feature
feats_padded
=
np
.
concatenate
([
feature
,
padding
],
axis
=
0
)
return
feats_padded
def
_make_msa_df
(
chain_features
:
Mapping
[
str
,
np
.
ndarray
])
->
pd
.
DataFrame
:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa
=
chain_features
[
'msa_all_seq'
]
query_seq
=
chain_msa
[
0
]
per_seq_similarity
=
np
.
sum
(
query_seq
[
None
]
==
chain_msa
,
axis
=-
1
)
/
float
(
len
(
query_seq
))
per_seq_gap
=
np
.
sum
(
chain_msa
==
21
,
axis
=-
1
)
/
float
(
len
(
query_seq
))
msa_df
=
pd
.
DataFrame
({
'msa_species_identifiers'
:
chain_features
[
'msa_species_identifiers_all_seq'
],
'msa_row'
:
np
.
arange
(
len
(
chain_features
[
'msa_species_identifiers_all_seq'
])),
'msa_similarity'
:
per_seq_similarity
,
'gap'
:
per_seq_gap
})
return
msa_df
def
_create_species_dict
(
msa_df
:
pd
.
DataFrame
)
->
Dict
[
bytes
,
pd
.
DataFrame
]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup
=
{}
for
species
,
species_df
in
msa_df
.
groupby
(
'msa_species_identifiers'
):
species_lookup
[
species
]
=
species_df
return
species_lookup
def
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
:
List
[
pd
.
DataFrame
]
)
->
List
[
List
[
int
]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows
=
[]
num_seqs
=
[
len
(
species_df
)
for
species_df
in
this_species_msa_dfs
if
species_df
is
not
None
]
take_num_seqs
=
np
.
min
(
num_seqs
)
sort_by_similarity
=
(
lambda
x
:
x
.
sort_values
(
'msa_similarity'
,
axis
=
0
,
ascending
=
False
))
for
species_df
in
this_species_msa_dfs
:
if
species_df
is
not
None
:
species_df_sorted
=
sort_by_similarity
(
species_df
)
msa_rows
=
species_df_sorted
.
msa_row
.
iloc
[:
take_num_seqs
].
values
else
:
msa_rows
=
[
-
1
]
*
take_num_seqs
# take the last 'padding' row
all_paired_msa_rows
.
append
(
msa_rows
)
all_paired_msa_rows
=
list
(
np
.
array
(
all_paired_msa_rows
).
transpose
())
return
all_paired_msa_rows
def pair_sequences(
    examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains."""
    num_examples = len(examples)

    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    common_species = sorted(common_species)
    common_species.remove(b'')  # Remove target sequence species.

    all_paired_msa_rows = [np.zeros(len(examples), int)]
    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        if np.any(
            np.array([len(species_df) for species_df in this_species_msa_dfs
                      if isinstance(species_df, pd.DataFrame)]) > 600
        ):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(
            this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)

    all_paired_msa_rows_dict = {
        num_examples: np.array(paired_msa_rows)
        for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict

def reorder_paired_rows(
    all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
    """Creates a list of indices of paired MSA rows across chains.

    Args:
        all_paired_msa_rows_dict: a mapping from the number of paired chains to
            the paired indices.

    Returns:
        a list of lists, each containing indices of paired MSA rows across
        chains. The paired-index lists are ordered by:
            1) the number of chains in the paired alignment, i.e., all-chain
               pairings will come first.
            2) e-values
    """
    all_paired_msa_rows = []
    for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
        paired_rows = all_paired_msa_rows_dict[num_pairings]
        paired_rows_product = abs(
            np.array([np.prod(rows) for rows in paired_rows]))
        paired_rows_sort_index = np.argsort(paired_rows_product)
        all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

    return np.array(all_paired_msa_rows)

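# An end-to-end sketch of the pairing step using the two functions above, on
# two hypothetical chains (the species names and arrays are made up; only
# species present in more than one chain are paired):
chain_a = {
    'msa_all_seq': np.array([[0, 1], [0, 2], [0, 3]], dtype=np.int32),
    'msa_species_identifiers_all_seq': np.array([b'', b'ECOLI', b'HUMAN']),
}
chain_b = {
    'msa_all_seq': np.array([[4, 5], [4, 6]], dtype=np.int32),
    'msa_species_identifiers_all_seq': np.array([b'', b'ECOLI']),
}
paired = pair_sequences([chain_a, chain_b])
# paired[2] pairs rows across both chains: the query pair plus the ECOLI pair.
# HUMAN occurs in one chain only, so it is skipped.
paired_rows = reorder_paired_rows(paired)  # -> array([[0, 0], [1, 1]])
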
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    """Like scipy.linalg.block_diag but with an optional padding value."""
    ones_arrs = [np.ones_like(x) for x in arrs]
    off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
    diag = scipy.linalg.block_diag(*arrs)
    diag += (off_diag_mask * pad_value).astype(diag.dtype)
    return diag

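# A tiny numeric check of the padding behaviour (arbitrary values):
#   block_diag(np.array([[1, 2], [3, 4]]), np.array([[5]]), pad_value=-1)
#   -> [[ 1,  2, -1],
#       [ 3,  4, -1],
#       [-1, -1,  5]]
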
def _correct_post_merged_feats(
    np_example: Mapping[str, np.ndarray],
    np_chains_list: Sequence[Mapping[str, np.ndarray]],
    pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Adds features that need to be computed/recomputed post merging."""
    num_res = np_example['aatype'].shape[0]
    np_example['seq_length'] = np.asarray([num_res] * num_res, dtype=np.int32)
    np_example['num_alignments'] = np.asarray(
        np_example['msa'].shape[0], dtype=np.int32)

    if not pair_msa_sequences:
        # Generate a bias that is 1 for the first row of every block in the
        # block diagonal MSA - i.e. make sure the cluster stack always includes
        # the query sequences for each chain (since the first row is the query
        # sequence).
        cluster_bias_masks = []
        for chain in np_chains_list:
            mask = np.zeros(chain['msa'].shape[0])
            mask[0] = 1
            cluster_bias_masks.append(mask)
        np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
        ]
        np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
    else:
        np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
        np_example['cluster_bias_mask'][0] = 1

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.float32) for x in np_chains_list
        ]
        msa_masks_all_seq = [
            np.ones(x['msa_all_seq'].shape, dtype=np.float32)
            for x in np_chains_list
        ]

        msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
        msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
        np_example['bert_mask'] = np.concatenate(
            [msa_mask_all_seq, msa_mask_block_diag], axis=0)
    return np_example

def _pad_templates(
    chains: Sequence[Mapping[str, np.ndarray]],
    max_templates: int
) -> Sequence[Mapping[str, np.ndarray]]:
    """For each chain pad the number of templates to a fixed size.

    Args:
        chains: A list of protein chains.
        max_templates: Each chain will be padded to have this many templates.

    Returns:
        The list of chains, updated to have template features padded to
        max_templates.
    """
    for chain in chains:
        for k, v in chain.items():
            if k in TEMPLATE_FEATURES:
                padding = np.zeros_like(v.shape)
                padding[0] = max_templates - v.shape[0]
                padding = [(0, p) for p in padding]
                chain[k] = np.pad(v, padding, mode='constant')
    return chains

def _merge_features_from_multiple_chains(
    chains: Sequence[Mapping[str, np.ndarray]],
    pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Merge features from multiple chains.

    Args:
        chains: A list of feature dictionaries that we want to merge.
        pair_msa_sequences: Whether to concatenate MSA features along the
            num_res dimension (if True), or to block diagonalize them (if
            False).

    Returns:
        A feature dictionary for the merged example.
    """
    merged_example = {}
    for feature_name in chains[0]:
        feats = [x[feature_name] for x in chains]
        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            if pair_msa_sequences or '_all_seq' in feature_name:
                merged_example[feature_name] = np.concatenate(feats, axis=1)
            else:
                merged_example[feature_name] = block_diag(
                    *feats, pad_value=MSA_PAD_VALUES[feature_name])
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            merged_example[feature_name] = np.sum(
                x for x in feats).astype(np.int32)
        else:
            merged_example[feature_name] = feats[0]
    return merged_example

def _merge_homomers_dense_msa(
    chains: Iterable[Mapping[str, np.ndarray]]
) -> Sequence[Mapping[str, np.ndarray]]:
    """Merge all identical chains, making the resulting MSA dense.

    Args:
        chains: An iterable of features for each chain.

    Returns:
        A list of feature dictionaries. All features with the same entity_id
        will be merged - MSA features will be concatenated along the num_res
        dimension - making them dense.
    """
    entity_chains = collections.defaultdict(list)
    for chain in chains:
        entity_id = chain['entity_id'][0]
        entity_chains[entity_id].append(chain)

    grouped_chains = []
    for entity_id in sorted(entity_chains):
        chains = entity_chains[entity_id]
        grouped_chains.append(chains)

    chains = [
        _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
        for chains in grouped_chains
    ]
    return chains

def _concatenate_paired_and_unpaired_features(
    example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Merges paired and block-diagonalised features."""
    features = MSA_FEATURES
    for feature_name in features:
        if feature_name in example:
            feat = example[feature_name]
            feat_all_seq = example[feature_name + '_all_seq']
            merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
            example[feature_name] = merged_feat
    example['num_alignments'] = np.array(
        example['msa'].shape[0], dtype=np.int32)
    return example

def merge_chain_features(
    np_chains_list: List[Mapping[str, np.ndarray]],
    pair_msa_sequences: bool,
    max_templates: int
) -> Mapping[str, np.ndarray]:
    """Merges features for multiple chains to single FeatureDict.

    Args:
        np_chains_list: List of FeatureDicts for each chain.
        pair_msa_sequences: Whether to merge paired MSAs.
        max_templates: The maximum number of templates to include.

    Returns:
        Single FeatureDict for entire complex.
    """
    np_chains_list = _pad_templates(
        np_chains_list, max_templates=max_templates)
    np_chains_list = _merge_homomers_dense_msa(np_chains_list)
    # Unpaired MSA features will be always block-diagonalised; paired MSA
    # features will be concatenated.
    np_example = _merge_features_from_multiple_chains(
        np_chains_list, pair_msa_sequences=False)

    if pair_msa_sequences:
        np_example = _concatenate_paired_and_unpaired_features(np_example)

    np_example = _correct_post_merged_feats(
        np_example=np_example,
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences)

    return np_example

def deduplicate_unpaired_sequences(
    np_chains: List[Mapping[str, np.ndarray]]
) -> List[Mapping[str, np.ndarray]]:
    """Removes unpaired sequences which duplicate a paired sequence."""
    feature_names = np_chains[0].keys()
    msa_features = MSA_FEATURES

    for chain in np_chains:
        # Convert the msa_all_seq numpy array to a tuple for hashing.
        sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
        keep_rows = []
        # Go through unpaired MSA seqs and remove any rows that correspond to
        # the sequences that are already present in the paired MSA.
        for row_num, seq in enumerate(chain['msa']):
            if tuple(seq) not in sequence_set:
                keep_rows.append(row_num)
        for feature_name in feature_names:
            if feature_name in msa_features:
                chain[feature_name] = chain[feature_name][keep_rows]
        chain['num_alignments'] = np.array(
            chain['msa'].shape[0], dtype=np.int32)
    return np_chains

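# A minimal sketch on toy arrays (assumes 'msa' is listed in MSA_FEATURES, as
# defined earlier in this module): the unpaired row [2, 3] duplicates a paired
# row and is dropped.
toy_chain = {
    'msa_all_seq': np.array([[0, 1], [2, 3]], dtype=np.int32),  # paired rows
    'msa': np.array([[2, 3], [4, 5]], dtype=np.int32),          # unpaired rows
}
deduplicate_unpaired_sequences([toy_chain])
# toy_chain['msa'] -> array([[4, 5]]); toy_chain['num_alignments'] -> 1
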
openfold/data/parsers.py

@@ -16,14 +16,43 @@
 """Functions for parsing various file formats."""
 import collections
 import dataclasses
+import itertools
 import re
 import string
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

 DeletionMatrix = Sequence[Sequence[int]]
+
+
+@dataclasses.dataclass(frozen=True)
+class Msa:
+    """Class representing a parsed MSA file"""
+    sequences: Sequence[str]
+    deletion_matrix: DeletionMatrix
+    descriptions: Optional[Sequence[str]]
+
+    def __post_init__(self):
+        if (not (
+            len(self.sequences) ==
+            len(self.deletion_matrix) ==
+            len(self.descriptions)
+        )):
+            raise ValueError(
+                "All fields for an MSA must have the same length"
+            )
+
+    def __len__(self):
+        return len(self.sequences)
+
+    def truncate(self, max_seqs: int):
+        return Msa(
+            sequences=self.sequences[:max_seqs],
+            deletion_matrix=self.deletion_matrix[:max_seqs],
+            descriptions=self.descriptions[:max_seqs],
+        )
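A quick sketch of the new container on toy values (assuming the Msa class exactly as added above):

    msa = Msa(
        sequences=['MKV', 'MRV'],
        deletion_matrix=[[0, 0, 0], [0, 1, 0]],
        descriptions=['query', 'hit_1'],
    )
    assert len(msa) == 2
    msa.truncate(max_seqs=1)  # keeps only the query row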
@@ -31,7 +60,7 @@ class TemplateHit:
     index: int
     name: str
     aligned_cols: int
-    sum_probs: float
+    sum_probs: Optional[float]
     query: str
     hit_sequence: str
     indices_query: List[int]

@@ -69,9 +98,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     return sequences, descriptions


-def parse_stockholm(
-    stockholm_string: str,
-) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+def parse_stockholm(stockholm_string: str) -> Msa:
     """Parses sequences and deletion matrix from stockholm format alignment.

     Args:

@@ -126,10 +153,14 @@ def parse_stockholm(
             deletion_count = 0
         deletion_matrix.append(deletion_vec)

-    return msa, deletion_matrix, list(name_to_sequence.keys())
+    return Msa(
+        sequences=msa,
+        deletion_matrix=deletion_matrix,
+        descriptions=list(name_to_sequence.keys())
+    )


-def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
+def parse_a3m(a3m_string: str) -> Msa:
     """Parses sequences and deletion matrix from a3m format alignment.

     Args:

@@ -144,7 +175,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
       at `deletion_matrix[i][j]` is the number of residues deleted from
       the aligned sequence i at residue position j.
     """
-    sequences, _ = parse_fasta(a3m_string)
+    sequences, descriptions = parse_fasta(a3m_string)
     deletion_matrix = []
     for msa_sequence in sequences:
         deletion_vec = []

@@ -160,7 +191,11 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
     # Make the MSA matrix out of aligned (deletion-free) sequences.
     deletion_table = str.maketrans("", "", string.ascii_lowercase)
     aligned_sequences = [s.translate(deletion_table) for s in sequences]
-    return aligned_sequences, deletion_matrix
+    return Msa(
+        sequences=aligned_sequences,
+        deletion_matrix=deletion_matrix,
+        descriptions=descriptions
+    )


 def _convert_sto_seq_to_a3m(

@@ -174,7 +209,9 @@ def _convert_sto_seq_to_a3m(
 def convert_stockholm_to_a3m(
-    stockholm_format: str, max_sequences: Optional[int] = None
+    stockholm_format: str,
+    max_sequences: Optional[int] = None,
+    remove_first_row_gaps: bool = True,
 ) -> str:
     """Converts MSA in Stockholm format to the A3M format."""
     descriptions = {}

@@ -212,13 +249,19 @@ def convert_stockholm_to_a3m(
     # Convert sto format to a3m line by line
     a3m_sequences = {}
-    # query_sequence is assumed to be the first sequence
-    query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != "-" for res in query_sequence]
+    if remove_first_row_gaps:
+        # query_sequence is assumed to be the first sequence
+        query_sequence = next(iter(sequences.values()))
+        query_non_gaps = [res != "-" for res in query_sequence]
     for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = "".join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
-        )
+        # Dots are optional in a3m format and are commonly removed.
+        out_sequence = sto_sequence.replace('.', '')
+        if remove_first_row_gaps:
+            out_sequence = ''.join(
+                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
+            )
+        a3m_sequences[seqname] = out_sequence

     fasta_chunks = (
         f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
@@ -227,6 +270,124 @@ def convert_stockholm_to_a3m(
     return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.


+def _keep_line(line: str, seqnames: Set[str]) -> bool:
+    """Function to decide which lines to keep."""
+    if not line.strip():
+        return True
+    if line.strip() == '//':  # End tag
+        return True
+    if line.startswith('# STOCKHOLM'):  # Start tag
+        return True
+    if line.startswith('#=GC RF'):  # Reference Annotation Line
+        return True
+    if line[:4] == '#=GS':  # Description lines - keep if sequence in list.
+        _, seqname, _ = line.split(maxsplit=2)
+        return seqname in seqnames
+    elif line.startswith('#'):  # Other markup - filter out
+        return False
+    else:  # Alignment data - keep if sequence in list.
+        seqname = line.partition(' ')[0]
+        return seqname in seqnames
+
+
+def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
+    """Reads + truncates a Stockholm file while preventing excessive RAM usage."""
+    seqnames = set()
+    filtered_lines = []
+
+    with open(stockholm_msa_path) as f:
+        for line in f:
+            if line.strip() and not line.startswith(('#', '//')):
+                # Ignore blank lines, markup and end symbols - remainder are
+                # alignment sequence parts.
+                seqname = line.partition(' ')[0]
+                seqnames.add(seqname)
+                if len(seqnames) >= max_sequences:
+                    break
+
+        f.seek(0)
+        for line in f:
+            if _keep_line(line, seqnames):
+                filtered_lines.append(line)
+
+    return ''.join(filtered_lines)
+
+
+def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
+    """Removes empty columns (dashes-only) from a Stockholm MSA."""
+    processed_lines = {}
+    unprocessed_lines = {}
+    for i, line in enumerate(stockholm_msa.splitlines()):
+        if line.startswith('#=GC RF'):
+            reference_annotation_i = i
+            reference_annotation_line = line
+            # Reached the end of this chunk of the alignment. Process chunk.
+            _, _, first_alignment = line.rpartition(' ')
+            mask = []
+            for j in range(len(first_alignment)):
+                for _, unprocessed_line in unprocessed_lines.items():
+                    prefix, _, alignment = unprocessed_line.rpartition(' ')
+                    if alignment[j] != '-':
+                        mask.append(True)
+                        break
+                else:  # Every row contained a hyphen - empty column.
+                    mask.append(False)
+            # Add reference annotation for processing with mask.
+            unprocessed_lines[reference_annotation_i] = reference_annotation_line
+
+            if not any(mask):
+                # All columns were empty. Output empty lines for chunk.
+                for line_index in unprocessed_lines:
+                    processed_lines[line_index] = ''
+            else:
+                for line_index, unprocessed_line in unprocessed_lines.items():
+                    prefix, _, alignment = unprocessed_line.rpartition(' ')
+                    masked_alignment = ''.join(
+                        itertools.compress(alignment, mask))
+                    processed_lines[line_index] = f'{prefix} {masked_alignment}'
+
+            # Clear raw_alignments.
+            unprocessed_lines = {}
+        elif line.strip() and not line.startswith(('#', '//')):
+            unprocessed_lines[i] = line
+        else:
+            processed_lines[i] = line
+    return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
+
+
+def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
+    """Remove duplicate sequences (ignoring insertions wrt query)."""
+    sequence_dict = collections.defaultdict(str)
+
+    # First we must extract all sequences from the MSA.
+    for line in stockholm_msa.splitlines():
+        # Only consider the alignments - ignore reference annotation, empty
+        # lines, descriptions or markup.
+        if line.strip() and not line.startswith(('#', '//')):
+            line = line.strip()
+            seqname, alignment = line.split()
+            sequence_dict[seqname] += alignment
+
+    seen_sequences = set()
+    seqnames = set()
+    # First alignment is the query.
+    query_align = next(iter(sequence_dict.values()))
+    mask = [c != '-' for c in query_align]  # Mask is False for insertions.
+    for seqname, alignment in sequence_dict.items():
+        # Apply mask to remove all insertions from the string.
+        masked_alignment = ''.join(itertools.compress(alignment, mask))
+        if masked_alignment in seen_sequences:
+            continue
+        else:
+            seen_sequences.add(masked_alignment)
+            seqnames.add(seqname)
+
+    filtered_lines = []
+    for line in stockholm_msa.splitlines():
+        if _keep_line(line, seqnames):
+            filtered_lines.append(line)
+
+    return '\n'.join(filtered_lines) + '\n'
+
+
 def _get_hhr_line_regex_groups(
     regex_pattern: str, line: str
 ) -> Sequence[Optional[str]]:
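A small sketch of deduplicate_stockholm_msa on a hypothetical alignment; hit2 matches hit1 once insertion columns (gaps in the query row) are masked out, so it is dropped:

    sto = ('# STOCKHOLM 1.0\n'
           'query AA-AA\n'
           'hit1 ACCAA\n'
           'hit2 ACGAA\n'
           '//\n')
    deduplicate_stockholm_msa(sto)
    # keeps the header, query and hit1 lines, plus the // terminator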
@@ -280,7 +441,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             "Could not parse section: %s. Expected this:\n%s to contain summary."
             % (detailed_lines, detailed_lines[2])
         )
-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
         float(x) for x in match.groups()
     ]

@@ -388,3 +549,115 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
         target_name = fields[0]
         e_values[target_name] = float(e_value)
     return e_values
+
+
+def _get_indices(sequence: str, start: int) -> List[int]:
+    """Returns indices for non-gap/insert residues starting at the given index."""
+    indices = []
+    counter = start
+    for symbol in sequence:
+        # Skip gaps but add a placeholder so that the alignment is preserved.
+        if symbol == '-':
+            indices.append(-1)
+        # Skip deleted residues, but increase the counter.
+        elif symbol.islower():
+            counter += 1
+        # Normal aligned residue. Increase the counter and append to indices.
+        else:
+            indices.append(counter)
+            counter += 1
+    return indices
+
+
+@dataclasses.dataclass(frozen=True)
+class HitMetadata:
+    pdb_id: str
+    chain: str
+    start: int
+    end: int
+    length: int
+    text: str
+
+
+def _parse_hmmsearch_description(description: str) -> HitMetadata:
+    """Parses the hmmsearch A3M sequence description line."""
+    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
+    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
+    match = re.match(
+        r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
+        description.strip())
+
+    if not match:
+        raise ValueError(f'Could not parse description: "{description}".')
+
+    return HitMetadata(
+        pdb_id=match[1],
+        chain=match[2],
+        start=int(match[3]),
+        end=int(match[4]),
+        length=int(match[5]),
+        text=match[6])
+
+
+def parse_hmmsearch_a3m(
+    query_sequence: str,
+    a3m_string: str,
+    skip_first: bool = True
+) -> Sequence[TemplateHit]:
+    """Parses an a3m string produced by hmmsearch.
+
+    Args:
+        query_sequence: The query sequence.
+        a3m_string: The a3m string produced by hmmsearch.
+        skip_first: Whether to skip the first sequence in the a3m string.
+
+    Returns:
+        A sequence of `TemplateHit` results.
+    """
+    # Zip the descriptions and MSAs together, skip the first query sequence.
+    parsed_a3m = list(zip(*parse_fasta(a3m_string)))
+    if skip_first:
+        parsed_a3m = parsed_a3m[1:]
+
+    indices_query = _get_indices(query_sequence, start=0)
+
+    hits = []
+    for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
+        if 'mol:protein' not in hit_description:
+            continue  # Skip non-protein chains.
+        metadata = _parse_hmmsearch_description(hit_description)
+        # Aligned columns are only the match states.
+        aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
+        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)
+
+        hit = TemplateHit(
+            index=i,
+            name=f'{metadata.pdb_id}_{metadata.chain}',
+            aligned_cols=aligned_cols,
+            sum_probs=None,
+            query=query_sequence,
+            hit_sequence=hit_sequence.upper(),
+            indices_query=indices_query,
+            indices_hit=indices_hit,
+        )
+        hits.append(hit)
+
+    return hits
+
+
+def parse_hmmsearch_sto(
+    output_string: str, input_sequence: str
+) -> Sequence[TemplateHit]:
+    """Gets parsed template hits from the raw string output by the tool."""
+    a3m_string = convert_stockholm_to_a3m(
+        output_string, remove_first_row_gaps=False)
+    template_hits = parse_hmmsearch_a3m(
+        query_sequence=input_sequence,
+        a3m_string=a3m_string,
+        skip_first=False)
+    return template_hits
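The description regex accepts exactly the formats shown in the code comments above, e.g.:

    meta = _parse_hmmsearch_description(
        '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text')
    # meta.pdb_id == '4pqx', meta.chain == 'A', meta.start == 2,
    # meta.end == 217, meta.length == 217, meta.text == 'Free text'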
openfold/data/templates.py

@@ -14,8 +14,10 @@
 # limitations under the License.
 """Functions for getting templates and calculating template features."""
+import abc
 import dataclasses
 import datetime
+import functools
 import glob
 import json
 import logging

@@ -65,10 +67,6 @@ class DateError(PrefilterError):
     """An error indicating that the hit date was after the max allowed date."""


-class PdbIdError(PrefilterError):
-    """An error indicating that the hit PDB ID was identical to the query."""
-
-
 class AlignRatioError(PrefilterError):
     """An error indicating that the hit align ratio to the query was too small."""

@@ -204,7 +202,6 @@ def _assess_hhsearch_hit(
     hit: parsers.TemplateHit,
     hit_pdb_code: str,
     query_sequence: str,
-    query_pdb_code: Optional[str],
     release_dates: Mapping[str, datetime.datetime],
     release_date_cutoff: datetime.datetime,
     max_subsequence_ratio: float = 0.95,

@@ -218,7 +215,6 @@ def _assess_hhsearch_hit(
         different from the value in the actual hit since the original pdb might
         have become obsolete.
       query_sequence: Amino acid sequence of the query.
-      query_pdb_code: 4 letter pdb code of the query.
       release_dates: Dictionary mapping pdb codes to their structure release
         dates.
       release_date_cutoff: Max release date that is valid for this query.

@@ -230,7 +226,6 @@ def _assess_hhsearch_hit(
     Raises:
       DateError: If the hit date was after the max allowed date.
-      PdbIdError: If the hit PDB ID was identical to the query.
       AlignRatioError: If the hit align ratio to the query was too small.
       DuplicateError: If the hit was an exact subsequence of the query.
       LengthError: If the hit was too short.

@@ -241,13 +236,6 @@ def _assess_hhsearch_hit(
     template_sequence = hit.hit_sequence.replace("-", "")
     length_ratio = float(len(template_sequence)) / len(query_sequence)

-    # Check whether the template is a large subsequence or duplicate of original
-    # query. This can happen due to duplicate entries in the PDB database.
-    duplicate = (
-        template_sequence in query_sequence
-        and length_ratio > max_subsequence_ratio
-    )
-
     if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
         date = release_dates[hit_pdb_code.upper()]
         raise DateError(

@@ -255,16 +243,19 @@ def _assess_hhsearch_hit(
             f"({release_date_cutoff})."
         )

-    if query_pdb_code is not None:
-        if query_pdb_code.lower() == hit_pdb_code.lower():
-            raise PdbIdError("PDB code identical to Query PDB code.")
-
     if align_ratio <= min_align_ratio:
         raise AlignRatioError(
             "Proportion of residues aligned to query too small. "
             f"Align ratio: {align_ratio}."
         )

+    # Check whether the template is a large subsequence or duplicate of original
+    # query. This can happen due to duplicate entries in the PDB database.
+    duplicate = (
+        template_sequence in query_sequence
+        and length_ratio > max_subsequence_ratio
+    )
     if duplicate:
         raise DuplicateError(
             "Template is an exact subsequence of query with large "

@@ -424,9 +415,10 @@ def _realign_pdb_template_to_query(
     )
     try:
-        (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
+        parsed_a3m = parsers.parse_a3m(
             aligner.align([old_template_sequence, new_template_sequence])
         )
+        old_aligned_template, new_aligned_template = parsed_a3m.sequences
     except Exception as e:
         raise QueryToTemplateAlignError(
             "Could not align old template %s to template %s (%s_%s). Error: %s"

@@ -768,7 +760,6 @@ class SingleHitResult:

 def _prefilter_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     max_template_date: datetime.datetime,
     release_dates: Mapping[str, datetime.datetime],

@@ -789,17 +780,14 @@ def _prefilter_hit(
             hit=hit,
             hit_pdb_code=hit_pdb_code,
             query_sequence=query_sequence,
-            query_pdb_code=query_pdb_code,
             release_dates=release_dates,
             release_date_cutoff=max_template_date,
         )
     except PrefilterError as e:
         hit_name = f"{hit_pdb_code}_{hit_chain_id}"
         msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
-        logging.info("%s: %s", query_pdb_code, msg)
-        if strict_error_check and isinstance(
-            e, (DateError, PdbIdError, DuplicateError)
-        ):
+        logging.info(msg)
+        if strict_error_check and isinstance(e, (DateError, DuplicateError)):
             # In strict mode we treat some prefilter cases as errors.
             return PrefilterResult(valid=False, error=msg, warning=None)

@@ -808,9 +796,16 @@ def _prefilter_hit(
     return PrefilterResult(valid=True, error=None, warning=None)


+@functools.lru_cache(16, typed=False)
+def _read_file(path):
+    with open(path, 'r') as f:
+        file_data = f.read()
+    return file_data
+
+
 def _process_single_hit(
     query_sequence: str,
-    query_pdb_code: Optional[str],
     hit: parsers.TemplateHit,
     mmcif_dir: str,
     max_template_date: datetime.datetime,

@@ -847,9 +842,9 @@ def _process_single_hit(
         query_sequence,
         template_sequence,
     )
     # Fail if we can't find the mmCIF file.
-    with open(cif_path, "r") as cif_file:
-        cif_string = cif_file.read()
+    cif_string = _read_file(cif_path)

     parsing_result = mmcif_parsing.parse(
         file_id=hit_pdb_code, mmcif_string=cif_string

@@ -882,7 +877,11 @@ def _process_single_hit(
             kalign_binary_path=kalign_binary_path,
             _zero_center_positions=_zero_center_positions,
         )
-        features["template_sum_probs"] = [hit.sum_probs]
+        if hit.sum_probs is None:
+            features["template_sum_probs"] = [0]
+        else:
+            features["template_sum_probs"] = [hit.sum_probs]

     # It is possible there were some errors when parsing the other chains in the
     # mmCIF file, but the template features for the chain we want were still

@@ -903,7 +902,7 @@ def _process_single_hit(
             % (
                 hit_pdb_code,
                 hit_chain_id,
-                hit.sum_probs,
+                hit.sum_probs if hit.sum_probs else 0.,
                 hit.index,
                 str(e),
                 parsing_result.errors,

@@ -920,7 +919,7 @@ def _process_single_hit(
             % (
                 hit_pdb_code,
                 hit_chain_id,
-                hit.sum_probs,
+                hit.sum_probs if hit.sum_probs else 0.,
                 hit.index,
                 str(e),
                 parsing_result.errors,

@@ -986,8 +985,8 @@ class TemplateSearchResult:
     warnings: Sequence[str]


-class TemplateHitFeaturizer:
-    """A class for turning hhr hits to template features."""
+class TemplateHitFeaturizer(abc.ABC):
+    """An abstract base class for turning template hits to features."""

     def __init__(
         self,
         mmcif_dir: str,

@@ -1036,7 +1035,7 @@ class TemplateHitFeaturizer:
             raise ValueError(
                 "max_template_date must be set and have format YYYY-MM-DD."
             )
-        self.max_hits = max_hits
+        self._max_hits = max_hits
         self._kalign_binary_path = kalign_binary_path
         self._strict_error_check = strict_error_check

@@ -1059,31 +1058,29 @@ class TemplateHitFeaturizer:
         self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
         self._zero_center_positions = _zero_center_positions

+    @abc.abstractmethod
+    def get_templates(
+        self,
+        query_sequence: str,
+        hits: Sequence[parsers.TemplateHit]
+    ) -> TemplateSearchResult:
+        """ Computes the templates for a given query sequence """
+
+
+class HhsearchHitFeaturizer(TemplateHitFeaturizer):
     def get_templates(
         self,
         query_sequence: str,
-        query_pdb_code: Optional[str],
-        query_release_date: Optional[datetime.datetime],
         hits: Sequence[parsers.TemplateHit],
     ) -> TemplateSearchResult:
         """Computes the templates for given query sequence (more details above)."""
-        logging.info("Searching for template for: %s", query_pdb_code)
+        logging.info("Searching for template for: %s", query_sequence)

         template_features = {}
         for template_feature_name in TEMPLATE_FEATURES:
             template_features[template_feature_name] = []

-        # Always use a max_template_date. Set to query_release_date minus 60 days
-        # if that's earlier.
-        template_cutoff_date = self._max_template_date
-        if query_release_date:
-            delta = datetime.timedelta(days=60)
-            if query_release_date - delta < template_cutoff_date:
-                template_cutoff_date = query_release_date - delta
-            assert template_cutoff_date < query_release_date
-        assert template_cutoff_date <= self._max_template_date
-
-        num_hits = 0
+        already_seen = set()
         errors = []
         warnings = []

@@ -1091,9 +1088,8 @@ class TemplateHitFeaturizer:
         for hit in hits:
             prefilter_result = _prefilter_hit(
                 query_sequence=query_sequence,
-                query_pdb_code=query_pdb_code,
                 hit=hit,
-                max_template_date=template_cutoff_date,
+                max_template_date=self._max_template_date,
                 release_dates=self._release_dates,
                 obsolete_pdbs=self._obsolete_pdbs,
                 strict_error_check=self._strict_error_check,

@@ -1119,17 +1115,16 @@ class TemplateHitFeaturizer:
         for i in idx:
             # We got all the templates we wanted, stop processing hits.
-            if num_hits >= self.max_hits:
+            if len(already_seen) >= self._max_hits:
                 break

             hit = filtered[i]
             result = _process_single_hit(
                 query_sequence=query_sequence,
-                query_pdb_code=query_pdb_code,
                 hit=hit,
                 mmcif_dir=self._mmcif_dir,
-                max_template_date=template_cutoff_date,
+                max_template_date=self._max_template_date,
                 release_dates=self._release_dates,
                 obsolete_pdbs=self._obsolete_pdbs,
                 strict_error_check=self._strict_error_check,

@@ -1153,22 +1148,152 @@ class TemplateHitFeaturizer:
                     result.warning,
                 )
             else:
-                # Increment the hit counter, since we got features out of this hit.
-                num_hits += 1
+                already_seen_key = result.features["template_sequence"]
+                if already_seen_key in already_seen:
+                    continue
+                already_seen.add(already_seen_key)
                 for k in template_features:
                     template_features[k].append(result.features[k])

-        for name in template_features:
-            if num_hits > 0:
+        if already_seen:
+            for name in template_features:
                 template_features[name] = np.stack(
                     template_features[name], axis=0
                 ).astype(TEMPLATE_FEATURES[name])
-            else:
-                # Make sure the feature has correct dtype even if empty.
-                template_features[name] = np.array(
-                    [], dtype=TEMPLATE_FEATURES[name]
-                )
+        else:
+            num_res = len(query_sequence)
+            # Construct a default template with all zeros.
+            template_features = {
+                "template_aatype": np.zeros(
+                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
+                    np.float32),
+                "template_all_atom_masks": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num), np.float32),
+                "template_all_atom_positions": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num, 3), np.float32),
+                "template_domain_names": np.array([''.encode()], dtype=np.object),
+                "template_sequence": np.array([''.encode()], dtype=np.object),
+                "template_sum_probs": np.array([0], dtype=np.float32),
+            }

         return TemplateSearchResult(
             features=template_features, errors=errors, warnings=warnings
         )
+
+
+class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
+    def get_templates(
+        self,
+        query_sequence: str,
+        hits: Sequence[parsers.TemplateHit]
+    ) -> TemplateSearchResult:
+        logging.info("Searching for template for: %s", query_sequence)
+        template_features = {}
+        for template_feature_name in TEMPLATE_FEATURES:
+            template_features[template_feature_name] = []
+
+        already_seen = set()
+        errors = []
+        warnings = []
+
+        # DISCREPANCY: This filtering scheme that saves time
+        filtered = []
+        for hit in hits:
+            prefilter_result = _prefilter_hit(
+                query_sequence=query_sequence,
+                hit=hit,
+                max_template_date=self._max_template_date,
+                release_dates=self._release_dates,
+                obsolete_pdbs=self._obsolete_pdbs,
+                strict_error_check=self._strict_error_check,
+            )
+            if prefilter_result.error:
+                errors.append(prefilter_result.error)
+            if prefilter_result.warning:
+                warnings.append(prefilter_result.warning)
+            if prefilter_result.valid:
+                filtered.append(hit)
+
+        filtered = list(sorted(
+            filtered,
+            key=lambda x: x.sum_probs if x.sum_probs else 0.,
+            reverse=True,
+        ))
+        idx = list(range(len(filtered)))
+        if self._shuffle_top_k_prefiltered:
+            stk = self._shuffle_top_k_prefiltered
+            idx[:stk] = np.random.permutation(idx[:stk])
+
+        for i in idx:
+            if len(already_seen) >= self._max_hits:
+                break
+            hit = filtered[i]
+            result = _process_single_hit(
+                query_sequence=query_sequence,
+                hit=hit,
+                mmcif_dir=self._mmcif_dir,
+                max_template_date=self._max_template_date,
+                release_dates=self._release_dates,
+                obsolete_pdbs=self._obsolete_pdbs,
+                strict_error_check=self._strict_error_check,
+                kalign_binary_path=self._kalign_binary_path
+            )
+
+            if result.error:
+                errors.append(result.error)
+            if result.warning:
+                warnings.append(result.warning)
+
+            if result.features is None:
+                logging.debug(
+                    "Skipped invalid hit %s, error: %s, warning: %s",
+                    hit.name,
+                    result.error,
+                    result.warning,
+                )
+            else:
+                already_seen_key = result.features["template_sequence"]
+                if already_seen_key in already_seen:
+                    continue
+                # Increment the hit counter, since we got features out of this hit.
+                already_seen.add(already_seen_key)
+                for k in template_features:
+                    template_features[k].append(result.features[k])
+
+        if already_seen:
+            for name in template_features:
+                template_features[name] = np.stack(
+                    template_features[name], axis=0
+                ).astype(TEMPLATE_FEATURES[name])
+        else:
+            num_res = len(query_sequence)
+            # Construct a default template with all zeros.
+            template_features = {
+                "template_aatype": np.zeros(
+                    (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
+                    np.float32),
+                "template_all_atom_masks": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num), np.float32),
+                "template_all_atom_positions": np.zeros(
+                    (1, num_res, residue_constants.atom_type_num, 3), np.float32),
+                "template_domain_names": np.array([''.encode()], dtype=np.object),
+                "template_sequence": np.array([''.encode()], dtype=np.object),
+                "template_sum_probs": np.array([0], dtype=np.float32),
+            }
+
+        return TemplateSearchResult(
+            features=template_features,
+            errors=errors,
+            warnings=warnings,
+        )
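The _read_file helper introduced above memoizes mmCIF reads with functools.lru_cache, so repeated hits into the same file share one disk read. A standalone sketch of the same pattern (hypothetical file contents, not OpenFold code):

    import functools, tempfile

    @functools.lru_cache(16)
    def read_file(path):
        with open(path, 'r') as f:
            return f.read()

    with tempfile.NamedTemporaryFile('w', suffix='.cif', delete=False) as tmp:
        tmp.write('data_demo')
    assert read_file(tmp.name) == read_file(tmp.name)  # second call served from cache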
openfold/data/tools/hhblits.py

@@ -18,7 +18,7 @@ import glob
 import logging
 import os
 import subprocess
-from typing import Any, Mapping, Optional, Sequence
+from typing import Any, List, Mapping, Optional, Sequence

 from openfold.data.tools import utils

@@ -99,9 +99,9 @@ class HHBlits:
         self.p = p
         self.z = z

-    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
+    def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
         """Queries the database using HHblits."""
-        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+        with utils.tmpdir_manager() as query_tmp_dir:
             a3m_path = os.path.join(query_tmp_dir, "output.a3m")

             db_cmd = []

@@ -172,4 +172,4 @@ class HHBlits:
             n_iter=self.n_iter,
             e_value=self.e_value,
         )
-        return raw_output
+        return [raw_output]
openfold/data/tools/hhsearch.py

@@ -18,8 +18,9 @@ import glob
 import logging
 import os
 import subprocess
-from typing import Sequence
+from typing import Sequence, Optional

+from openfold.data import parsers
 from openfold.data.tools import utils

@@ -62,11 +63,20 @@ class HHSearch:
                 f"Could not find HHsearch database {database_path}"
             )

-    def query(self, a3m: str) -> str:
+    @property
+    def output_format(self) -> str:
+        return 'hhr'
+
+    @property
+    def input_format(self) -> str:
+        return 'a3m'
+
+    def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
         """Queries the database using HHsearch using a given a3m."""
-        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+        with utils.tmpdir_manager() as query_tmp_dir:
             input_path = os.path.join(query_tmp_dir, "query.a3m")
-            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
+            output_dir = query_tmp_dir if output_dir is None else output_dir
+            hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
             with open(input_path, "w") as f:
                 f.write(a3m)

@@ -104,3 +114,12 @@ class HHSearch:
         with open(hhr_path) as f:
             hhr = f.read()
         return hhr
+
+    @staticmethod
+    def get_template_hits(
+        output_string: str, input_sequence: str
+    ) -> Sequence[parsers.TemplateHit]:
+        """Gets parsed template hits from the raw string output by the tool"""
+        del input_sequence  # Used by hmmsearch but not needed for hhsearch
+        return parsers.parse_hhr(output_string)
openfold/data/tools/hmmbuild.py 0 → 100644

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""

import os
import re
import subprocess

from absl import logging

from openfold.data.tools import utils


class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self, *, binary_path: str, singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces
                HMMBuild to just use a common substitution score matrix.

        Raises:
            RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds a HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            raise ValueError(
                f'Invalid model_construction {model_construction} - only'
                'hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
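Before profiling, build_profile_from_a3m strips lowercase (inserted) residues so every row aligns; a minimal standalone sketch of that preprocessing on a toy A3M:

    import re

    a3m = '>query\nMKVA\n>hit\nMkVA\n'
    lines = []
    for line in a3m.splitlines():
        if not line.startswith('>'):
            line = re.sub('[a-z]+', '', line)  # drop inserted residues
        lines.append(line + '\n')
    # ''.join(lines) -> '>query\nMKVA\n>hit\nMVA\n'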
openfold/data/tools/hmmsearch.py 0 → 100644

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""

import os
import subprocess
from typing import Optional, Sequence

from absl import logging

from openfold.data import parsers
from openfold.data.tools import hmmbuild
from openfold.data.tools import utils


class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(
            binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(
                f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand')
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self, hmm: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
        output_string: str, input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
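For reference, the command line that query_with_hmm assembles with the default flags looks like this (the paths here are hypothetical placeholders):

    cmd = (['hmmsearch', '--noali', '--cpu', '8']
           + ['--F1', '0.1', '--F2', '0.1', '--F3', '0.1',
              '--incE', '100', '-E', '100', '--domE', '100', '--incdomE', '100']
           + ['-A', '/tmp/out/hmm_output.sto', '/tmp/query.hmm',
              '/data/pdb_seqres.txt'])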
openfold/data/tools/jackhmmer.py
View file @
56d5e39c
...
@@ -23,6 +23,7 @@ import subprocess
 from typing import Any, Callable, Mapping, Optional, Sequence
 from urllib import request
+from openfold.data import parsers
 from openfold.data.tools import utils
...
@@ -93,10 +94,13 @@ class Jackhmmer:
         self.streaming_callback = streaming_callback

     def _query_chunk(
-        self, input_fasta_path: str, database_path: str
+        self, input_fasta_path: str, database_path: str,
+        max_sequences: Optional[int] = None
     ) -> Mapping[str, Any]:
         """Queries the database chunk using Jackhmmer."""
-        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+        with utils.tmpdir_manager() as query_tmp_dir:
             sto_path = os.path.join(query_tmp_dir, "output.sto")
             # The F1/F2/F3 are the expected proportion to pass each of the filtering
...
@@ -167,8 +171,11 @@ class Jackhmmer:
         with open(tblout_path) as f:
             tbl = f.read()

-        with open(sto_path) as f:
-            sto = f.read()
+        if (max_sequences is None):
+            with open(sto_path) as f:
+                sto = f.read()
+        else:
+            sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

         raw_output = dict(
             sto=sto,
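(Aside, not part of the diff: the truncation branch defers to parsers.truncate_stockholm_msa, added in this commit's parsers.py changes. A rough sketch of its semantics follows; the real implementation also handles per-sequence #=GS/#=GR annotations, which this simplified version does not.)

def truncate_stockholm_msa_sketch(stockholm_path: str, max_sequences: int) -> str:
    """Keeps only the first max_sequences sequences of a Stockholm MSA."""
    kept, out_lines = set(), []
    with open(stockholm_path) as f:
        for line in f:
            if line.strip() and not line.startswith(('#', '//')):
                seqname = line.split(maxsplit=1)[0]
                if seqname not in kept:
                    if len(kept) >= max_sequences:
                        continue  # drop alignment rows of later sequences
                    kept.add(seqname)
            out_lines.append(line)
    return ''.join(out_lines)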
...
@@ -180,10 +187,16 @@ class Jackhmmer:
         return raw_output

-    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+    def query(
+        self,
+        input_fasta_path: str,
+        max_sequences: Optional[int] = None,
+    ) -> Sequence[Mapping[str, Any]]:
         """Queries the database using Jackhmmer."""
         if self.num_streamed_chunks is None:
-            return [self._query_chunk(input_fasta_path, self.database_path)]
+            single_chunk_result = self._query_chunk(
+                input_fasta_path, self.database_path, max_sequences,
+            )
+            return [single_chunk_result]

         db_basename = os.path.basename(self.database_path)
         db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
...
@@ -217,12 +230,20 @@ class Jackhmmer:
                 # Run Jackhmmer with the chunk
                 future.result()
                 chunked_output.append(
-                    self._query_chunk(input_fasta_path, db_local_chunk(i))
+                    self._query_chunk(
+                        input_fasta_path, db_local_chunk(i), max_sequences
+                    )
                 )
                 # Remove the local copy of the chunk
                 os.remove(db_local_chunk(i))
-                future = next_future
+                # Do not set next_future for the last chunk so that this works
+                # even for databases with only 1 chunk
+                if (i < self.num_streamed_chunks):
+                    future = next_future
                 if self.streaming_callback:
                     self.streaming_callback(i)
         return chunked_output
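(Aside, not part of the diff: a hedged usage sketch of the new max_sequences parameter. Constructor arguments are abbreviated and the paths are placeholders.)

from openfold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',      # assumed install location
    database_path='/data/uniref90.fasta',  # single-chunk database
)

# Returns the full Stockholm alignment, as before.
results = runner.query('query.fasta')

# Truncates the Stockholm output to the first 512 sequences via
# parsers.truncate_stockholm_msa before returning it.
capped = runner.query('query.fasta', max_sequences=512)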
openfold/data/tools/kalign.py
View file @
56d5e39c
...
@@ -72,7 +72,7 @@ class Kalign:
                 "residues long. Got %s (%d residues)." % (s, len(s))
             )

-        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+        with utils.tmpdir_manager() as query_tmp_dir:
             input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
             output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
...
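(Aside, not part of the diff: jackhmmer.py, kalign.py and the new hmmsearch.py above all call utils.tmpdir_manager() without the hardcoded base_dir="/tmp", deferring to the system default temporary directory and the TMPDIR environment variable. A plausible shape for that helper, assuming it wraps the standard tempfile module; the actual implementation lives in openfold/data/tools/utils.py.)

import contextlib
import shutil
import tempfile
from typing import Optional

@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Creates a temporary directory and deletes it on exit."""
    tmpdir = tempfile.mkdtemp(dir=base_dir)  # dir=None -> system default
    try:
        yield tmpdir
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)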