Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
56d5e39c
Commit
56d5e39c
authored
Jun 17, 2023
by
Geoffrey Yu
Browse files
Merge remote-tracking branch 'upstream/multimer' into multimer
parents
56b86074
51556d52
Changes
80
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2983 additions
and
284 deletions
+2983
-284
.gitignore
.gitignore
+2
-0
openfold/__init__.py
openfold/__init__.py
+1
-0
openfold/config.py
openfold/config.py
+287
-1
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+521
-153
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+70
-27
openfold/data/data_transforms_multimer.py
openfold/data/data_transforms_multimer.py
+303
-0
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+21
-7
openfold/data/feature_processing_multimer.py
openfold/data/feature_processing_multimer.py
+244
-0
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+136
-0
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+17
-1
openfold/data/msa_identifiers.py
openfold/data/msa_identifiers.py
+91
-0
openfold/data/msa_pairing.py
openfold/data/msa_pairing.py
+483
-0
openfold/data/parsers.py
openfold/data/parsers.py
+290
-17
openfold/data/templates.py
openfold/data/templates.py
+187
-62
openfold/data/tools/hhblits.py
openfold/data/tools/hhblits.py
+4
-4
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+23
-4
openfold/data/tools/hmmbuild.py
openfold/data/tools/hmmbuild.py
+137
-0
openfold/data/tools/hmmsearch.py
openfold/data/tools/hmmsearch.py
+137
-0
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+28
-7
openfold/data/tools/kalign.py
openfold/data/tools/kalign.py
+1
-1
No files found.
.gitignore
View file @
56d5e39c
.vscode/
.idea/
__pycache__/
*.egg-info
build
...
...
@@ -8,3 +9,4 @@ dist
data
openfold/resources/
tests/test_data/
openfold/__init__.py
View file @
56d5e39c
from
.
import
model
from
.
import
utils
from
.
import
data
from
.
import
np
from
.
import
resources
...
...
openfold/config.py
View file @
56d5e39c
import
re
import
copy
import
importlib
import
ml_collections
as
mlc
...
...
@@ -16,7 +17,7 @@ def enforce_config_constraints(config):
path
=
s
.
split
(
'.'
)
setting
=
config
for
p
in
path
:
setting
=
setting
[
p
]
setting
=
setting
.
get
(
p
)
return
setting
...
...
@@ -152,6 +153,48 @@ def model_config(
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
c
.
globals
.
is_multimer
=
True
c
.
globals
.
bfloat16
=
True
c
.
globals
.
bfloat16_output
=
False
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
data
.
common
.
max_recycling_iters
=
20
for
k
,
v
in
multimer_model_config_update
.
items
():
c
.
model
[
k
]
=
v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_msa_clusters
=
252
c
.
data
.
predict
.
max_msa_clusters
=
252
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
model
.
evoformer_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
elif
name
==
'model_4_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
elif
name
==
'model_5_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
else
:
c
.
data
.
train
.
max_msa_clusters
=
508
c
.
data
.
predict
.
max_msa_clusters
=
508
c
.
data
.
train
.
max_extra_msa
=
2048
c
.
data
.
predict
.
max_extra_msa
=
2048
c
.
data
.
common
.
unsupervised_features
.
extend
([
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
])
else
:
raise
ValueError
(
"Invalid model name"
)
...
...
@@ -380,6 +423,7 @@ config = mlc.ConfigDict(
"c_e"
:
c_e
,
"c_s"
:
c_s
,
"eps"
:
eps
,
"is_multimer"
:
False
,
},
"model"
:
{
"_mask_trans"
:
False
,
...
...
@@ -423,6 +467,8 @@ config = mlc.ConfigDict(
"no_heads"
:
4
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
False
,
"fuse_projection_weights"
:
False
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"tune_chunk_size"
:
tune_chunk_size
,
"inf"
:
1e9
,
...
...
@@ -471,6 +517,8 @@ config = mlc.ConfigDict(
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
False
,
"fuse_projection_weights"
:
False
,
"clear_cache_between_blocks"
:
False
,
"tune_chunk_size"
:
tune_chunk_size
,
"inf"
:
1e9
,
...
...
@@ -493,6 +541,8 @@ config = mlc.ConfigDict(
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
False
,
"fuse_projection_weights"
:
False
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"tune_chunk_size"
:
tune_chunk_size
,
...
...
@@ -585,6 +635,7 @@ config = mlc.ConfigDict(
"weight"
:
0.01
,
},
"masked_msa"
:
{
"num_classes"
:
23
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
...
...
@@ -597,6 +648,7 @@ config = mlc.ConfigDict(
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
False
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.0
,
},
...
...
@@ -609,8 +661,242 @@ config = mlc.ConfigDict(
"weight"
:
0.
,
"enabled"
:
tm_enabled
,
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.
,
"eps"
:
eps
,
"enabled"
:
False
,
},
"eps"
:
eps
,
},
"ema"
:
{
"decay"
:
0.999
},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance"
:
-
1
}
)
multimer_model_config_update
=
{
"input_embedder"
:
{
"tf_dim"
:
21
,
"msa_dim"
:
49
,
#"num_msa": 508,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
"use_chain_relative"
:
True
,
},
"template"
:
{
"distogram"
:
{
"min_bin"
:
3.25
,
"max_bin"
:
50.75
,
"no_bins"
:
39
,
},
"template_pair_embedder"
:
{
"c_z"
:
c_z
,
"c_out"
:
64
,
"c_dgram"
:
39
,
"c_aatype"
:
22
,
},
"template_single_embedder"
:
{
"c_in"
:
34
,
"c_m"
:
c_m
,
},
"template_pair_stack"
:
{
"c_t"
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att"
:
16
,
"c_hidden_tri_mul"
:
64
,
"no_blocks"
:
2
,
"no_heads"
:
4
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"inf"
:
1e9
,
},
"c_t"
:
c_t
,
"c_z"
:
c_z
,
"inf"
:
1e5
,
# 1e9,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
True
},
"extra_msa"
:
{
"extra_msa_embedder"
:
{
"c_in"
:
25
,
"c_out"
:
c_e
,
#"num_extra_msa": 2048
},
"extra_msa_stack"
:
{
"c_m"
:
c_e
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
8
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
4
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"clear_cache_between_blocks"
:
True
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"ckpt"
:
blocks_per_ckpt
is
not
None
,
},
"enabled"
:
True
,
},
"evoformer_stack"
:
{
"c_m"
:
c_m
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
32
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"c_s"
:
c_s
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
48
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
},
"structure_module"
:
{
"c_s"
:
c_s
,
"c_z"
:
c_z
,
"c_ipa"
:
16
,
"c_resnet"
:
128
,
"no_heads_ipa"
:
12
,
"no_qk_points"
:
4
,
"no_v_points"
:
8
,
"dropout_rate"
:
0.1
,
"no_blocks"
:
8
,
"no_transition_layers"
:
1
,
"no_resnet_blocks"
:
2
,
"no_angles"
:
7
,
"trans_scale_factor"
:
20
,
"epsilon"
:
eps
,
# 1e-12,
"inf"
:
1e5
,
},
"heads"
:
{
"lddt"
:
{
"no_bins"
:
50
,
"c_in"
:
c_s
,
"c_hidden"
:
128
,
},
"distogram"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
},
"tm"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"ptm_weight"
:
0.2
,
"iptm_weight"
:
0.8
,
"enabled"
:
True
,
},
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_out"
:
22
,
},
"experimentally_resolved"
:
{
"c_s"
:
c_s
,
"c_out"
:
37
,
},
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
eps
,
# 1e-8,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.0
,
},
"fape"
:
{
"intra_chain_backbone"
:
{
"clamp_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
},
"interface"
:
{
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
},
"sidechain"
:
{
"clamp_distance"
:
10.0
,
"length_scale"
:
10.0
,
"weight"
:
0.5
,
},
"eps"
:
1e-4
,
"weight"
:
1.0
,
},
"plddt_loss"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.0
,
"no_bins"
:
50
,
"eps"
:
eps
,
# 1e-10,
"weight"
:
0.01
,
},
"masked_msa"
:
{
"num_classes"
:
23
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
1.0
,
},
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
True
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.03
,
# Not finetuning
},
"tm"
:
{
"max_bin"
:
31
,
"no_bins"
:
64
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.1
,
"enabled"
:
True
,
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.05
,
"eps"
:
eps
,
"enabled"
:
True
,
},
"eps"
:
eps
,
},
"recycle_early_stop_tolerance"
:
0.5
}
openfold/data/data_pipeline.py
View file @
56d5e39c
...
...
@@ -14,25 +14,30 @@
# limitations under the License.
import
os
import
datetime
import
copy
import
collections
import
contextlib
import
dataclasses
from
multiprocessing
import
cpu_count
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
import
tempfile
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
,
MutableMapping
,
Union
import
numpy
as
np
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
from
openfold.data
import
templates
,
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
feature_processing_multimer
from
openfold.data.templates
import
get_custom_template_features
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.np
import
residue_constants
,
protein
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
MutableMapping
[
str
,
np
.
ndarray
]
TemplateSearcher
=
Union
[
hhsearch
.
HHSearch
,
hmmsearch
.
Hmmsearch
]
def
empty_template_feats
(
n_res
)
->
FeatureDict
:
return
{
"template_aatype"
:
np
.
zeros
((
0
,
n_res
)).
astype
(
np
.
int64
),
"template_all_atom_positions"
:
"template_all_atom_positions"
:
np
.
zeros
((
0
,
n_res
,
37
,
3
)).
astype
(
np
.
float32
),
"template_sum_probs"
:
np
.
zeros
((
0
,
1
)).
astype
(
np
.
float32
),
"template_all_atom_mask"
:
np
.
zeros
((
0
,
n_res
,
37
)).
astype
(
np
.
float32
),
...
...
@@ -52,8 +57,6 @@ def make_template_features(
else
:
templates_result
=
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_pdb_code
=
query_pdb_code
,
query_release_date
=
query_release_date
,
hits
=
hits_cat
,
)
template_features
=
templates_result
.
features
...
...
@@ -85,7 +88,7 @@ def unify_template_features(
assert
(
new_shape
[
1
]
==
n_res
)
new_shape
[
1
]
=
sum
(
seq_lens
)
new_array
=
np
.
zeros
(
new_shape
,
dtype
=
v
.
dtype
)
if
(
k
==
"template_aatype"
):
new_array
[...,
residue_constants
.
HHBLITS_AA_TO_ID
[
'-'
]]
=
1
...
...
@@ -171,13 +174,13 @@ def make_mmcif_features(
def
_aatype_to_str_sequence
(
aatype
):
return
''
.
join
([
residue_constants
.
restypes_with_x
[
aatype
[
i
]]
residue_constants
.
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
])
def
make_protein_features
(
protein_object
:
protein
.
Protein
,
protein_object
:
protein
.
Protein
,
description
:
str
,
_is_distillation
:
bool
=
False
,
)
->
FeatureDict
:
...
...
@@ -224,32 +227,35 @@ def make_pdb_features(
return
pdb_feats
def
make_msa_features
(
msas
:
Sequence
[
Sequence
[
str
]],
deletion_matrices
:
Sequence
[
parsers
.
DeletionMatrix
],
)
->
FeatureDict
:
def
make_msa_features
(
msas
:
Sequence
[
parsers
.
Msa
])
->
FeatureDict
:
"""Constructs a feature dict of MSA features."""
if
not
msas
:
raise
ValueError
(
"At least one MSA must be provided."
)
int_msa
=
[]
deletion_matrix
=
[]
species_ids
=
[]
seen_sequences
=
set
()
for
msa_index
,
msa
in
enumerate
(
msas
):
if
not
msa
:
raise
ValueError
(
f
"MSA
{
msa_index
}
must contain at least one sequence."
)
for
sequence_index
,
sequence
in
enumerate
(
msa
):
for
sequence_index
,
sequence
in
enumerate
(
msa
.
sequences
):
if
sequence
in
seen_sequences
:
continue
seen_sequences
.
add
(
sequence
)
int_msa
.
append
(
[
residue_constants
.
HHBLITS_AA_TO_ID
[
res
]
for
res
in
sequence
]
)
deletion_matrix
.
append
(
deletion_matrices
[
msa_index
][
sequence_index
])
num_res
=
len
(
msas
[
0
][
0
])
deletion_matrix
.
append
(
msa
.
deletion_matrix
[
sequence_index
])
identifiers
=
msa_identifiers
.
get_identifiers
(
msa
.
descriptions
[
sequence_index
]
)
species_ids
.
append
(
identifiers
.
species_id
.
encode
(
'utf-8'
))
num_res
=
len
(
msas
[
0
].
sequences
[
0
])
num_alignments
=
len
(
int_msa
)
features
=
{}
features
[
"deletion_matrix_int"
]
=
np
.
array
(
deletion_matrix
,
dtype
=
np
.
int32
)
...
...
@@ -257,9 +263,29 @@ def make_msa_features(
features
[
"num_alignments"
]
=
np
.
array
(
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
np
.
object_
)
return
features
def
run_msa_tool
(
msa_runner
,
fasta_path
:
str
,
msa_out_path
:
str
,
msa_format
:
str
,
max_sto_sequences
:
Optional
[
int
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
"""Runs an MSA tool, checking if output already exists first."""
if
(
msa_format
==
"sto"
and
max_sto_sequences
is
not
None
):
result
=
msa_runner
.
query
(
fasta_path
,
max_sto_sequences
)[
0
]
else
:
result
=
msa_runner
.
query
(
fasta_path
)[
0
]
assert
msa_out_path
.
split
(
'.'
)[
-
1
]
==
msa_format
with
open
(
msa_out_path
,
"w"
)
as
f
:
f
.
write
(
result
[
msa_format
])
return
result
def
make_sequence_features_with_custom_template
(
sequence
:
str
,
mmcif_path
:
str
,
...
...
@@ -277,10 +303,11 @@ def make_sequence_features_with_custom_template(
num_res
=
num_res
,
)
msa_data
=
[[
sequence
]]
deletion_matrix
=
[[[
0
for
_
in
sequence
]]]
msa_data
=
[
sequence
]
deletion_matrix
=
[[
0
for
_
in
sequence
]]
msa_data_obj
=
parsers
.
Msa
(
sequences
=
msa_data
,
deletion_matrix
=
deletion_matrix
,
descriptions
=
None
)
msa_features
=
make_msa_features
(
msa_data
,
deletion_matrix
)
msa_features
=
make_msa_features
(
[
msa_data
_obj
]
)
template_features
=
get_custom_template_features
(
mmcif_path
=
mmcif_path
,
query_sequence
=
sequence
,
...
...
@@ -295,22 +322,25 @@ def make_sequence_features_with_custom_template(
**
template_features
.
features
}
class
AlignmentRunner
:
"""Runs alignment tools and saves the results"""
def
__init__
(
self
,
jackhmmer_binary_path
:
Optional
[
str
]
=
None
,
hhblits_binary_path
:
Optional
[
str
]
=
None
,
hhsearch_binary_path
:
Optional
[
str
]
=
None
,
uniref90_database_path
:
Optional
[
str
]
=
None
,
mgnify_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
uniref30_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
pdb70_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
template_searcher
:
Optional
[
TemplateSearcher
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
uniprot_max_hits
:
int
=
50000
,
):
"""
Args:
...
...
@@ -318,8 +348,6 @@ class AlignmentRunner:
Path to jackhmmer binary
hhblits_binary_path:
Path to hhblits binary
hhsearch_binary_path:
Path to hhsearch binary
uniref90_database_path:
Path to uniref90 database. If provided, jackhmmer_binary_path
must also be provided
...
...
@@ -328,16 +356,17 @@ class AlignmentRunner:
must also be provided
bfd_database_path:
Path to BFD database. Depending on the value of use_small_bfd,
one of hhblits_binary_path or jackhmmer_binary_path must be
one of hhblits_binary_path or jackhmmer_binary_path must be
provided.
uniref30_database_path:
Path to uniref30. Searched alongside BFD if use_small_bfd is
false.
uniclust30_database_path:
Path to uniclust30. Searched alongside BFD if use_small_bfd is
Path to uniclust30. Searched alongside BFD if use_small_bfd is
false.
pdb70_database_path:
Path to pdb70 database.
use_small_bfd:
Whether to search the BFD database alone with jackhmmer or
in conjunction with uniclust30 with hhblits.
Whether to search the BFD database alone with jackhmmer or
in conjunction with
uniref30/
uniclust30 with hhblits.
no_cpus:
The number of CPUs available for alignment. By default, all
CPUs are used.
...
...
@@ -345,6 +374,8 @@ class AlignmentRunner:
Max number of uniref hits
mgnify_max_hits:
Max number of mgnify hits
uniprot_max_hits:
Max number of uniprot hits
"""
db_map
=
{
"jackhmmer"
:
{
...
...
@@ -353,6 +384,7 @@ class AlignmentRunner:
uniref90_database_path
,
mgnify_database_path
,
bfd_database_path
if
use_small_bfd
else
None
,
uniprot_database_path
,
],
},
"hhblits"
:
{
...
...
@@ -361,12 +393,6 @@ class AlignmentRunner:
bfd_database_path
if
not
use_small_bfd
else
None
,
],
},
"hhsearch"
:
{
"binary"
:
hhsearch_binary_path
,
"dbs"
:
[
pdb70_database_path
,
],
},
}
for
name
,
dic
in
db_map
.
items
():
...
...
@@ -376,22 +402,16 @@ class AlignmentRunner:
f
"
{
name
}
DBs provided but
{
name
}
binary is None"
)
if
(
not
all
([
x
is
None
for
x
in
db_map
[
"hhsearch"
][
"dbs"
]])
and
uniref90_database_path
is
None
):
raise
ValueError
(
"""uniref90_database_path must be specified in order to perform
template search"""
)
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
uniprot_max_hits
=
uniprot_max_hits
self
.
use_small_bfd
=
use_small_bfd
if
(
no_cpus
is
None
):
no_cpus
=
cpu_count
()
self
.
jackhmmer_uniref90_runner
=
None
if
(
jackhmmer_binary_path
is
not
None
and
if
(
jackhmmer_binary_path
is
not
None
and
uniref90_database_path
is
not
None
):
self
.
jackhmmer_uniref90_runner
=
jackhmmer
.
Jackhmmer
(
...
...
@@ -399,9 +419,9 @@ class AlignmentRunner:
database_path
=
uniref90_database_path
,
n_cpu
=
no_cpus
,
)
self
.
jackhmmer_small_bfd_runner
=
None
self
.
hhblits_bfd_uniclust_runner
=
None
self
.
hhblits_bfd_uni
ref
clust_runner
=
None
if
(
bfd_database_path
is
not
None
):
if
use_small_bfd
:
self
.
jackhmmer_small_bfd_runner
=
jackhmmer
.
Jackhmmer
(
...
...
@@ -411,9 +431,11 @@ class AlignmentRunner:
)
else
:
dbs
=
[
bfd_database_path
]
if
(
uniclust30_database_path
is
not
None
):
if
(
uniref30_database_path
is
not
None
):
dbs
.
append
(
uniref30_database_path
)
if
(
uniclust30_database_path
is
not
None
):
dbs
.
append
(
uniclust30_database_path
)
self
.
hhblits_bfd_uniclust_runner
=
hhblits
.
HHBlits
(
self
.
hhblits_bfd_uni
ref
clust_runner
=
hhblits
.
HHBlits
(
binary_path
=
hhblits_binary_path
,
databases
=
dbs
,
n_cpu
=
no_cpus
,
...
...
@@ -427,14 +449,22 @@ class AlignmentRunner:
n_cpu
=
no_cpus
,
)
self
.
hhsearch_pdb70_runner
=
None
if
(
pdb70_database_path
is
not
None
):
self
.
hhsearch_pdb70_runner
=
hhsearch
.
HHSearch
(
binary_path
=
hhsearch_binary_path
,
databases
=
[
pdb70_database_path
],
n_cpu
=
no_cpus
,
self
.
_uniprot_msa_runner
=
None
if
(
uniprot_database_path
is
not
None
):
self
.
jackhmmer_uniprot_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
uniprot_database_path
)
if
(
template_searcher
is
not
None
and
self
.
jackhmmer_uniref90_runner
is
None
):
raise
ValueError
(
"Uniref90 runner must be specified to run template search"
)
self
.
template_searcher
=
template_searcher
def
run
(
self
,
fasta_path
:
str
,
...
...
@@ -442,52 +472,226 @@ class AlignmentRunner:
):
"""Runs alignment tools on a sequence"""
if
(
self
.
jackhmmer_uniref90_runner
is
not
None
):
jackhmmer_uniref90_result
=
self
.
jackhmmer_uniref90_runner
.
query
(
fasta_path
)[
0
]
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_uniref90_result
[
"sto"
],
max_sequences
=
self
.
uniref_max_hits
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
"uniref90_hits.sto"
)
jackhmmer_uniref90_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_uniref90_runner
,
fasta_path
=
fasta_path
,
msa_out_path
=
uniref90_out_path
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniref_max_hits
,
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
"uniref90_hits.a3m"
)
with
open
(
uniref90_out_path
,
"w"
)
as
f
:
f
.
write
(
uniref90_msa_as_a3m
)
if
(
self
.
hhsearch_pdb70_runner
is
not
None
):
hhsearch_result
=
self
.
hhsearch_pdb70_runner
.
query
(
uniref90_msa_as_a3m
)
pdb70_out_path
=
os
.
path
.
join
(
output_dir
,
"pdb70_hits.hhr"
)
with
open
(
pdb70_out_path
,
"w"
)
as
f
:
f
.
write
(
hhsearch_result
)
template_msa
=
jackhmmer_uniref90_result
[
"sto"
]
template_msa
=
parsers
.
deduplicate_stockholm_msa
(
template_msa
)
template_msa
=
parsers
.
remove_empty_columns_from_stockholm_msa
(
template_msa
)
if
(
self
.
template_searcher
is
not
None
):
if
(
self
.
template_searcher
.
input_format
==
"sto"
):
pdb_templates_result
=
self
.
template_searcher
.
query
(
template_msa
,
output_dir
=
output_dir
)
elif
(
self
.
template_searcher
.
input_format
==
"a3m"
):
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
template_msa
)
pdb_templates_result
=
self
.
template_searcher
.
query
(
uniref90_msa_as_a3m
,
output_dir
=
output_dir
)
else
:
fmt
=
self
.
template_searcher
.
input_format
raise
ValueError
(
f
"Unrecognized template input format:
{
fmt
}
"
)
if
(
self
.
jackhmmer_mgnify_runner
is
not
None
):
jackhmmer_mgnify_result
=
self
.
jackhmmer_mgnify_runner
.
query
(
fasta_path
)[
0
]
mgnify_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
jackhmmer_mgnify_result
[
"sto"
],
max_sequences
=
self
.
mgnify_max_hits
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.sto"
)
jackhmmer_mgnify_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_mgnify_runner
,
fasta_path
=
fasta_path
,
msa_out_path
=
mgnify_out_path
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
mgnify_max_hits
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.a3m"
)
with
open
(
mgnify_out_path
,
"w"
)
as
f
:
f
.
write
(
mgnify_msa_as_a3m
)
if
(
self
.
use_small_bfd
and
self
.
jackhmmer_small_bfd_runner
is
not
None
):
jackhmmer_small_bfd_result
=
self
.
jackhmmer_small_bfd_runner
.
query
(
fasta_path
)[
0
]
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"small_bfd_hits.sto"
)
with
open
(
bfd_out_path
,
"w"
)
as
f
:
f
.
write
(
jackhmmer_small_bfd_result
[
"sto"
])
elif
(
self
.
hhblits_bfd_uniclust_runner
is
not
None
):
hhblits_bfd_uniclust_result
=
(
self
.
hhblits_bfd_uniclust_runner
.
query
(
fasta_path
)
jackhmmer_small_bfd_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_small_bfd_runner
,
fasta_path
=
fasta_path
,
msa_out_path
=
bfd_out_path
,
msa_format
=
"sto"
,
)
elif
(
self
.
hhblits_bfd_unirefclust_runner
is
not
None
):
uni_name
=
"uni"
for
db_name
in
self
.
hhblits_bfd_unirefclust_runner
.
databases
:
if
"uniref"
in
db_name
.
lower
():
uni_name
=
f
"
{
uni_name
}
ref"
elif
"uniclust"
in
db_name
.
lower
():
uni_name
=
f
"
{
uni_name
}
clust"
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
f
"bfd_
{
uni_name
}
_hits.a3m"
)
hhblits_bfd_unirefclust_result
=
run_msa_tool
(
msa_runner
=
self
.
hhblits_bfd_unirefclust_runner
,
fasta_path
=
fasta_path
,
msa_out_path
=
bfd_out_path
,
msa_format
=
"a3m"
,
)
if
(
self
.
jackhmmer_uniprot_runner
is
not
None
):
uniprot_out_path
=
os
.
path
.
join
(
output_dir
,
'uniprot_hits.sto'
)
result
=
run_msa_tool
(
self
.
jackhmmer_uniprot_runner
,
fasta_path
=
fasta_path
,
msa_out_path
=
uniprot_out_path
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniprot_max_hits
,
)
if
output_dir
is
not
None
:
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"bfd_uniclust_hits.a3m"
)
with
open
(
bfd_out_path
,
"w"
)
as
f
:
f
.
write
(
hhblits_bfd_uniclust_result
[
"a3m"
])
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
_FastaChain
:
sequence
:
str
description
:
str
def
_make_chain_id_map
(
sequences
:
Sequence
[
str
],
descriptions
:
Sequence
[
str
],
)
->
Mapping
[
str
,
_FastaChain
]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if
len
(
sequences
)
!=
len
(
descriptions
):
raise
ValueError
(
'sequences and descriptions must have equal length. '
f
'Got
{
len
(
sequences
)
}
!=
{
len
(
descriptions
)
}
.'
)
if
len
(
sequences
)
>
protein
.
PDB_MAX_CHAINS
:
raise
ValueError
(
'Cannot process more chains than the PDB format supports. '
f
'Got
{
len
(
sequences
)
}
chains.'
)
chain_id_map
=
{}
for
chain_id
,
sequence
,
description
in
zip
(
protein
.
PDB_CHAIN_IDS
,
sequences
,
descriptions
):
chain_id_map
[
chain_id
]
=
_FastaChain
(
sequence
=
sequence
,
description
=
description
)
return
chain_id_map
@
contextlib
.
contextmanager
def
temp_fasta_file
(
fasta_str
:
str
):
with
tempfile
.
NamedTemporaryFile
(
'w'
,
suffix
=
'.fasta'
)
as
fasta_file
:
fasta_file
.
write
(
fasta_str
)
fasta_file
.
seek
(
0
)
yield
fasta_file
.
name
def
convert_monomer_features
(
monomer_features
:
FeatureDict
,
chain_id
:
str
)
->
FeatureDict
:
"""Reshapes and modifies monomer features for multimer models."""
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object_
)
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
for
feature_name
,
feature
in
monomer_features
.
items
():
if
feature_name
in
unnecessary_leading_dim_feats
:
# asarray ensures it's a np.ndarray.
feature
=
np
.
asarray
(
feature
[
0
],
dtype
=
feature
.
dtype
)
elif
feature_name
==
'aatype'
:
# The multimer model performs the one-hot operation itself.
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
elif
feature_name
==
'template_aatype'
:
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature
=
np
.
take
(
new_order_list
,
feature
.
astype
(
np
.
int32
),
axis
=
0
)
elif
feature_name
==
'template_all_atom_masks'
:
feature_name
=
'template_all_atom_mask'
converted
[
feature_name
]
=
feature
return
converted
def
int_id_to_str_id
(
num
:
int
)
->
str
:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if
num
<=
0
:
raise
ValueError
(
f
'Only positive integers allowed, got
{
num
}
.'
)
num
=
num
-
1
# 1-based indexing.
output
=
[]
while
num
>=
0
:
output
.
append
(
chr
(
num
%
26
+
ord
(
'A'
)))
num
=
num
//
26
-
1
return
''
.
join
(
output
)
def
add_assembly_features
(
all_chain_features
:
MutableMapping
[
str
,
FeatureDict
],
)
->
MutableMapping
[
str
,
FeatureDict
]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id
=
{}
grouped_chains
=
collections
.
defaultdict
(
list
)
for
chain_id
,
chain_features
in
all_chain_features
.
items
():
seq
=
str
(
chain_features
[
'sequence'
])
if
seq
not
in
seq_to_entity_id
:
seq_to_entity_id
[
seq
]
=
len
(
seq_to_entity_id
)
+
1
grouped_chains
[
seq_to_entity_id
[
seq
]].
append
(
chain_features
)
new_all_chain_features
=
{}
chain_id
=
1
for
entity_id
,
group_chain_features
in
grouped_chains
.
items
():
for
sym_id
,
chain_features
in
enumerate
(
group_chain_features
,
start
=
1
):
new_all_chain_features
[
f
'
{
int_id_to_str_id
(
entity_id
)
}
_
{
sym_id
}
'
]
=
chain_features
seq_length
=
chain_features
[
'seq_length'
]
chain_features
[
'asym_id'
]
=
(
chain_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_features
[
'sym_id'
]
=
(
sym_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_features
[
'entity_id'
]
=
(
entity_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_id
+=
1
return
new_all_chain_features
def
pad_msa
(
np_example
,
min_num_seq
):
np_example
=
dict
(
np_example
)
num_seq
=
np_example
[
'msa'
].
shape
[
0
]
if
num_seq
<
min_num_seq
:
for
feat
in
(
'msa'
,
'deletion_matrix'
,
'bert_mask'
,
'msa_mask'
):
np_example
[
feat
]
=
np
.
pad
(
np_example
[
feat
],
((
0
,
min_num_seq
-
num_seq
),
(
0
,
0
)))
np_example
[
'cluster_bias_mask'
]
=
np
.
pad
(
np_example
[
'cluster_bias_mask'
],
((
0
,
min_num_seq
-
num_seq
),))
return
np_example
class
DataPipeline
:
...
...
@@ -503,7 +707,7 @@ class DataPipeline:
alignment_dir
:
str
,
alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
msa_data
=
{}
if
(
alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
alignment_index
[
"db"
]),
"rb"
)
...
...
@@ -513,49 +717,46 @@ class DataPipeline:
return
msa
for
(
name
,
start
,
size
)
in
alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)
[
-
1
]
filename
,
ext
=
os
.
path
.
splitext
(
name
)
if
(
ext
==
".a3m"
):
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
msa
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
msa
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
))
else
:
continue
msa_data
[
name
]
=
dat
a
msa_data
[
name
]
=
ms
a
fp
.
close
()
else
:
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)
[
-
1
]
filename
,
ext
=
os
.
path
.
splitext
(
f
)
if
(
ext
==
".a3m"
):
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
fp
.
read
())
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
msa
=
parsers
.
parse_stockholm
(
fp
.
read
()
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
continue
msa_data
[
f
]
=
dat
a
msa_data
[
f
]
=
ms
a
return
msa_data
def
_parse_template_hits
(
def
_parse_template_hit
_file
s
(
self
,
alignment_dir
:
str
,
input_sequence
:
str
,
alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
...
...
@@ -572,6 +773,12 @@ class DataPipeline:
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
elif
(
name
==
"hmmsearch_output.sto"
):
hits
=
parsers
.
parse_hmmsearch_sto
(
read_template
(
start
,
size
),
input_sequence
,
)
all_hits
[
name
]
=
hits
fp
.
close
()
else
:
...
...
@@ -583,9 +790,47 @@ class DataPipeline:
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
elif
(
f
==
"hmm_output.sto"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hmmsearch_sto
(
fp
.
read
(),
input_sequence
,
)
all_hits
[
f
]
=
hits
return
all_hits
def
_parse_template_hits
(
self
,
alignment_dir
:
str
,
alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
if
(
alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
alignment_index
[
"db"
]),
'rb'
)
def
read_template
(
start
,
size
):
fp
.
seek
(
start
)
return
fp
.
read
(
size
).
decode
(
"utf-8"
)
for
(
name
,
start
,
size
)
in
alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)[
-
1
]
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
fp
.
close
()
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)[
-
1
]
if
(
ext
==
".hhr"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
def
_get_msas
(
self
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
...
...
@@ -600,16 +845,13 @@ class DataPipeline:
must be provided.
"""
)
msa_data
[
"dummy"
]
=
{
"msa"
:
[
input_sequence
],
"deletion_matrix"
:
[[
0
for
_
in
input_sequence
]],
}
msas
,
deletion_matrices
=
zip
(
*
[
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
deletion_matrix
=
[[
0
for
_
in
input_sequence
]]
msa_data
[
"dummy"
]
=
parsers
.
Msa
(
sequences
=
input_sequence
,
deletion_matrix
=
deletion_matrix
,
descriptions
=
None
)
return
msas
,
deletion_matrices
return
list
(
msa_data
.
values
())
def
_process_msa_feats
(
self
,
...
...
@@ -617,12 +859,11 @@ class DataPipeline:
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
msas
,
deletion_matrices
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
input_sequence
,
alignment_index
)
msa_features
=
make_msa_features
(
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
msas
=
msas
)
return
msa_features
...
...
@@ -633,7 +874,7 @@ class DataPipeline:
alignment_dir
:
str
,
alignment_index
:
Optional
[
str
]
=
None
,
)
->
FeatureDict
:
"""Assembles features for a single sequence in a FASTA file"""
"""Assembles features for a single sequence in a FASTA file"""
with
open
(
fasta_path
)
as
f
:
fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
...
...
@@ -645,7 +886,12 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
input_sequence
,
alignment_index
,
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -659,10 +905,10 @@ class DataPipeline:
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
alignment_index
)
return
{
**
sequence_features
,
**
msa_features
,
**
msa_features
,
**
template_features
}
...
...
@@ -689,14 +935,17 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
self
.
template_featurizer
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
])
)
msa_features
=
self
.
_process_msa_feats
(
alignment_dir
,
input_sequence
,
alignment_index
)
return
{
**
mmcif_feats
,
**
template_features
,
**
msa_features
}
...
...
@@ -727,15 +976,19 @@ class DataPipeline:
pdb_str
=
f
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_str
,
chain_id
)
input_sequence
=
_aatype_to_str_sequence
(
protein_object
.
aatype
)
input_sequence
=
_aatype_to_str_sequence
(
protein_object
.
aatype
)
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
pdb_path
))[
0
].
upper
()
pdb_feats
=
make_pdb_features
(
protein_object
,
description
,
protein_object
,
description
,
is_distillation
=
is_distillation
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -759,11 +1012,15 @@ class DataPipeline:
core_str
=
f
.
read
()
protein_object
=
protein
.
from_proteinnet_string
(
core_str
)
input_sequence
=
_aatype_to_str_sequence
(
protein_object
.
aatype
)
input_sequence
=
_aatype_to_str_sequence
(
protein_object
.
aatype
)
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_index
)
template_features
=
make_template_features
(
input_sequence
,
hits
,
...
...
@@ -775,10 +1032,10 @@ class DataPipeline:
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
def
process_multiseq_fasta
(
self
,
fasta_path
:
str
,
super_alignment_dir
:
str
,
ri_gap
:
int
=
200
,
)
->
FeatureDict
:
fasta_path
:
str
,
super_alignment_dir
:
str
,
ri_gap
:
int
=
200
,
)
->
FeatureDict
:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter (a.k.a. AlphaFold-Gap).
...
...
@@ -787,7 +1044,7 @@ class DataPipeline:
fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
fasta_str
)
# No whitespace allowed
input_descs
=
[
i
.
split
()[
0
]
for
i
in
input_descs
]
...
...
@@ -814,14 +1071,15 @@ class DataPipeline:
alignment_dir
=
os
.
path
.
join
(
super_alignment_dir
,
desc
)
msas
,
deletion_mats
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
seq
,
None
)
msa_list
.
append
(
msas
)
deletion_mat_list
.
append
(
deletion_mat
s
)
msa_list
.
append
(
[
m
.
sequences
for
m
in
msas
]
)
deletion_mat_list
.
append
(
[
m
.
deletion_mat
rix
for
m
in
msas
])
final_msa
=
[]
final_deletion_mat
=
[]
final_msa_obj
=
[]
msa_it
=
enumerate
(
zip
(
msa_list
,
deletion_mat_list
))
for
i
,
(
msas
,
deletion_mats
)
in
msa_it
:
prec
,
post
=
sum
(
seq_lens
[:
i
]),
sum
(
seq_lens
[
i
+
1
:])
...
...
@@ -829,18 +1087,19 @@ class DataPipeline:
[
prec
*
'-'
+
seq
+
post
*
'-'
for
seq
in
msa
]
for
msa
in
msas
]
deletion_mats
=
[
[
prec
*
[
0
]
+
dml
+
post
*
[
0
]
for
dml
in
deletion_mat
]
[
prec
*
[
0
]
+
dml
+
post
*
[
0
]
for
dml
in
deletion_mat
]
for
deletion_mat
in
deletion_mats
]
assert
(
len
(
msas
[
0
][
-
1
])
==
len
(
input_sequence
))
assert
(
len
(
msas
[
0
][
-
1
])
==
len
(
input_sequence
))
final_msa
.
extend
(
msas
)
final_deletion_mat
.
extend
(
deletion_mats
)
final_msa_obj
.
extend
([
parsers
.
Msa
(
sequences
=
msas
[
k
],
deletion_matrix
=
deletion_mats
[
k
],
descriptions
=
None
)
for
k
in
range
(
len
(
msas
))])
msa_features
=
make_msa_features
(
msas
=
final_msa
,
deletion_matrices
=
final_deletion_mat
,
msas
=
final_msa_obj
)
template_feature_list
=
[]
...
...
@@ -860,6 +1119,115 @@ class DataPipeline:
return
{
**
sequence_features
,
**
msa_features
,
**
msa_features
,
**
template_features
,
}
class
DataPipelineMultimer
:
"""Runs the alignment tools and assembles the input features."""
def
__init__
(
self
,
monomer_data_pipeline
:
DataPipeline
,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self
.
_monomer_data_pipeline
=
monomer_data_pipeline
def
_process_single_chain
(
self
,
chain_id
:
str
,
sequence
:
str
,
description
:
str
,
chain_alignment_dir
:
str
,
is_homomer_or_monomer
:
bool
)
->
FeatureDict
:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str
=
f
'>
{
chain_id
}
\n
{
sequence
}
\n
'
if
not
os
.
path
.
exists
(
chain_alignment_dir
):
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
fasta_path
=
chain_fasta_path
,
alignment_dir
=
chain_alignment_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if
not
is_homomer_or_monomer
:
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_fasta_path
,
chain_alignment_dir
)
chain_features
.
update
(
all_seq_msa_features
)
return
chain_features
def
_all_seq_msa_features
(
self
,
fasta_path
,
alignment_dir
):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_species_identifiers'
,
)
feats
=
{
f
'
{
k
}
_all_seq'
:
v
for
k
,
v
in
all_seq_features
.
items
()
if
k
in
valid_feats
}
return
feats
def
process_fasta
(
self
,
fasta_path
:
str
,
alignment_dir
:
str
,
)
->
FeatureDict
:
"""Creates features."""
with
open
(
fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
all_chain_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
for
desc
,
seq
in
zip
(
input_descs
,
input_seqs
):
if
seq
in
sequence_features
:
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
sequence_features
[
seq
]
)
continue
chain_features
=
self
.
_process_single_chain
(
chain_id
=
desc
,
sequence
=
seq
,
description
=
desc
,
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
),
is_homomer_or_monomer
=
is_homomer_or_monomer
)
chain_features
=
convert_monomer_features
(
chain_features
,
chain_id
=
desc
)
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
return
np_example
\ No newline at end of file
openfold/data/data_transforms.py
View file @
56d5e39c
...
...
@@ -23,6 +23,9 @@ import torch
from
openfold.config
import
NUM_RES
,
NUM_EXTRA_SEQ
,
NUM_TEMPLATES
,
NUM_MSA_SEQ
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.tensor_utils
import
(
tree_map
,
tensor_tree_map
,
...
...
@@ -93,7 +96,7 @@ def fix_templates_aatype(protein):
# Map hhsearch-aatype to our aatype.
new_order_list
=
rc
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order
=
torch
.
tensor
(
new_order_list
,
dtype
=
torch
.
int64
,
device
=
protein
[
"aatype"
].
device
,
new_order_list
,
dtype
=
torch
.
int64
,
device
=
protein
[
"
template_
aatype"
].
device
,
).
expand
(
num_templates
,
-
1
)
protein
[
"template_aatype"
]
=
torch
.
gather
(
new_order
,
1
,
index
=
protein
[
"template_aatype"
]
...
...
@@ -439,13 +442,15 @@ def make_hhblits_profile(protein):
@
curry1
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
def
make_masked_msa
(
protein
,
config
,
replace_fraction
,
seed
):
"""Create data for BERT on raw MSA."""
device
=
protein
[
"msa"
].
device
# Add a random amino acid uniformly.
random_aa
=
torch
.
tensor
(
[
0.05
]
*
20
+
[
0.0
,
0.0
],
dtype
=
torch
.
float32
,
device
=
protein
[
"aatype"
].
device
device
=
device
)
categorical_probs
=
(
...
...
@@ -465,11 +470,17 @@ def make_masked_msa(protein, config, replace_fraction):
assert
mask_prob
>=
0.0
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
categorical_probs
,
pad_shapes
,
value
=
mask_prob
,
)
sh
=
protein
[
"msa"
].
shape
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
sample
=
torch
.
rand
(
sh
,
device
=
device
,
generator
=
g
)
mask_position
=
sample
<
replace_fraction
bert_msa
=
shaped_categorical
(
categorical_probs
)
bert_msa
=
torch
.
where
(
mask_position
,
bert_msa
,
protein
[
"msa"
])
...
...
@@ -662,7 +673,7 @@ def make_atom14_masks(protein):
def
make_atom14_masks_np
(
batch
):
batch
=
tree_map
(
lambda
n
:
torch
.
tensor
(
n
,
device
=
"cpu"
),
batch
,
batch
,
np
.
ndarray
)
out
=
make_atom14_masks
(
batch
)
...
...
@@ -728,7 +739,7 @@ def make_atom14_positions(protein):
for
index
,
correspondence
in
enumerate
(
correspondences
):
renaming_matrix
[
index
,
correspondence
]
=
1.0
all_matrices
[
resname
]
=
renaming_matrix
renaming_matrices
=
torch
.
stack
(
[
all_matrices
[
restype
]
for
restype
in
restype_3
]
)
...
...
@@ -774,10 +785,14 @@ def make_atom14_positions(protein):
def
atom37_to_frames
(
protein
,
eps
=
1e-8
):
is_multimer
=
"asym_id"
in
protein
aatype
=
protein
[
"aatype"
]
all_atom_positions
=
protein
[
"all_atom_positions"
]
all_atom_mask
=
protein
[
"all_atom_mask"
]
if
is_multimer
:
all_atom_positions
=
Vec3Array
.
from_array
(
all_atom_positions
)
batch_dims
=
len
(
aatype
.
shape
[:
-
1
])
restype_rigidgroup_base_atom_names
=
np
.
full
([
21
,
8
,
3
],
""
,
dtype
=
object
)
...
...
@@ -824,19 +839,37 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims
=
batch_dims
,
)
base_atom_pos
=
batched_gather
(
all_atom_positions
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
2
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
if
is_multimer
:
base_atom_pos
=
[
batched_gather
(
pos
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
1
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
1
]),
)
for
pos
in
all_atom_positions
]
base_atom_pos
=
Vec3Array
.
from_array
(
torch
.
stack
(
base_atom_pos
,
dim
=-
1
))
else
:
base_atom_pos
=
batched_gather
(
all_atom_positions
,
residx_rigidgroup_base_atom37_idx
,
dim
=-
2
,
no_batch_dims
=
len
(
all_atom_positions
.
shape
[:
-
2
]),
)
gt_frames
=
Rigid
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
eps
,
)
if
is_multimer
:
point_on_neg_x_axis
=
base_atom_pos
[:,
:,
0
]
origin
=
base_atom_pos
[:,
:,
1
]
point_on_xy_plane
=
base_atom_pos
[:,
:,
2
]
gt_rotation
=
Rot3Array
.
from_two_vectors
(
origin
-
point_on_neg_x_axis
,
point_on_xy_plane
-
origin
)
gt_frames
=
Rigid3Array
(
gt_rotation
,
origin
)
else
:
gt_frames
=
Rigid
.
from_3_points
(
p_neg_x_axis
=
base_atom_pos
[...,
0
,
:],
origin
=
base_atom_pos
[...,
1
,
:],
p_xy_plane
=
base_atom_pos
[...,
2
,
:],
eps
=
eps
,
)
group_exists
=
batched_gather
(
restype_rigidgroup_mask
,
...
...
@@ -857,9 +890,13 @@ def atom37_to_frames(protein, eps=1e-8):
rots
=
torch
.
tile
(
rots
,
(
*
((
1
,)
*
batch_dims
),
8
,
1
,
1
))
rots
[...,
0
,
0
,
0
]
=
-
1
rots
[...,
0
,
2
,
2
]
=
-
1
rots
=
Rotation
(
rot_mats
=
rots
)
gt_frames
=
gt_frames
.
compose
(
Rigid
(
rots
,
None
))
if
is_multimer
:
gt_frames
=
gt_frames
.
compose_rotation
(
Rot3Array
.
from_array
(
rots
))
else
:
rots
=
Rotation
(
rot_mats
=
rots
)
gt_frames
=
gt_frames
.
compose
(
Rigid
(
rots
,
None
))
restype_rigidgroup_is_ambiguous
=
all_atom_mask
.
new_zeros
(
*
((
1
,)
*
batch_dims
),
21
,
8
...
...
@@ -893,12 +930,18 @@ def atom37_to_frames(protein, eps=1e-8):
no_batch_dims
=
batch_dims
,
)
residx_rigidgroup_ambiguity_rot
=
Rotation
(
rot_mats
=
residx_rigidgroup_ambiguity_rot
)
alt_gt_frames
=
gt_frames
.
compose
(
Rigid
(
residx_rigidgroup_ambiguity_rot
,
None
)
)
if
is_multimer
:
ambiguity_rot
=
Rot3Array
.
from_array
(
residx_rigidgroup_ambiguity_rot
)
# Create the alternative ground truth frames.
alt_gt_frames
=
gt_frames
.
compose_rotation
(
ambiguity_rot
)
else
:
residx_rigidgroup_ambiguity_rot
=
Rotation
(
rot_mats
=
residx_rigidgroup_ambiguity_rot
)
alt_gt_frames
=
gt_frames
.
compose
(
Rigid
(
residx_rigidgroup_ambiguity_rot
,
None
)
)
gt_frames_tensor
=
gt_frames
.
to_tensor_4x4
()
alt_gt_frames_tensor
=
alt_gt_frames
.
to_tensor_4x4
()
...
...
openfold/data/data_transforms_multimer.py
0 → 100644
View file @
56d5e39c
from
typing
import
Sequence
import
torch
from
openfold.data.data_transforms
import
curry1
from
openfold.utils.tensor_utils
import
masked_mean
def
gumbel_noise
(
shape
:
Sequence
[
int
],
device
:
torch
.
device
,
eps
=
1e-6
,
generator
=
None
,
)
->
torch
.
Tensor
:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise
=
torch
.
rand
(
shape
,
dtype
=
torch
.
float32
,
device
=
device
,
generator
=
generator
)
gumbel
=
-
torch
.
log
(
-
torch
.
log
(
uniform_noise
+
eps
)
+
eps
)
return
gumbel
def
gumbel_max_sample
(
logits
:
torch
.
Tensor
,
generator
=
None
)
->
torch
.
Tensor
:
"""Samples from a probability distribution given by 'logits'.
This uses Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z
=
gumbel_noise
(
logits
.
shape
,
device
=
logits
.
device
,
generator
=
generator
)
return
torch
.
nn
.
functional
.
one_hot
(
torch
.
argmax
(
logits
+
z
,
dim
=-
1
),
logits
.
shape
[
-
1
],
)
def
gumbel_argsort_sample_idx
(
logits
:
torch
.
Tensor
,
generator
=
None
)
->
torch
.
Tensor
:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z
=
gumbel_noise
(
logits
.
shape
,
device
=
logits
.
device
,
generator
=
generator
)
return
torch
.
argsort
(
logits
+
z
,
dim
=-
1
,
descending
=
True
)
@
curry1
def
make_masked_msa
(
batch
,
config
,
replace_fraction
,
seed
,
eps
=
1e-6
):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa
=
torch
.
Tensor
(
[
0.05
]
*
20
+
[
0.
,
0.
],
device
=
batch
[
'msa'
].
device
)
categorical_probs
=
(
config
.
uniform_prob
*
random_aa
+
config
.
profile_prob
*
batch
[
'msa_profile'
]
+
config
.
same_prob
*
torch
.
nn
.
functional
.
one_hot
(
batch
[
'msa'
],
22
)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob
=
1.
-
config
.
profile_prob
-
config
.
same_prob
-
config
.
uniform_prob
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
[
0
,
1
],
value
=
mask_prob
)
sh
=
batch
[
'msa'
].
shape
mask_position
=
torch
.
rand
(
sh
,
device
=
batch
[
'msa'
].
device
)
<
replace_fraction
mask_position
*=
batch
[
'msa_mask'
].
to
(
mask_position
.
dtype
)
logits
=
torch
.
log
(
categorical_probs
+
eps
)
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
bert_msa
=
gumbel_max_sample
(
logits
,
generator
=
g
)
bert_msa
=
torch
.
where
(
mask_position
,
torch
.
argmax
(
bert_msa
,
dim
=-
1
),
batch
[
'msa'
]
)
bert_msa
*=
batch
[
'msa_mask'
].
to
(
bert_msa
.
dtype
)
# Mix real and masked MSA.
if
'bert_mask'
in
batch
:
batch
[
'bert_mask'
]
*=
mask_position
.
to
(
torch
.
float32
)
else
:
batch
[
'bert_mask'
]
=
mask_position
.
to
(
torch
.
float32
)
batch
[
'true_msa'
]
=
batch
[
'msa'
]
batch
[
'msa'
]
=
bert_msa
return
batch
@
curry1
def
nearest_neighbor_clusters
(
batch
,
gap_agreement_weight
=
0.
):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device
=
batch
[
"msa_mask"
].
device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
weights
=
torch
.
Tensor
(
[
1.
]
*
21
+
[
gap_agreement_weight
]
+
[
0.
],
device
=
device
,
)
msa_mask
=
batch
[
'msa_mask'
]
msa_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
'msa'
],
23
)
extra_mask
=
batch
[
'extra_msa_mask'
]
extra_one_hot
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
'extra_msa'
],
23
)
msa_one_hot_masked
=
msa_mask
[:,
:,
None
]
*
msa_one_hot
extra_one_hot_masked
=
extra_mask
[:,
:,
None
]
*
extra_one_hot
agreement
=
torch
.
einsum
(
'mrc, nrc->nm'
,
extra_one_hot_masked
,
weights
*
msa_one_hot_masked
)
cluster_assignment
=
torch
.
nn
.
functional
.
softmax
(
1e3
*
agreement
,
dim
=
0
)
cluster_assignment
*=
torch
.
einsum
(
'mr, nr->mn'
,
msa_mask
,
extra_mask
)
cluster_count
=
torch
.
sum
(
cluster_assignment
,
dim
=-
1
)
cluster_count
+=
1.
# We always include the sequence itself.
msa_sum
=
torch
.
einsum
(
'nm, mrc->nrc'
,
cluster_assignment
,
extra_one_hot_masked
)
msa_sum
+=
msa_one_hot_masked
cluster_profile
=
msa_sum
/
cluster_count
[:,
None
,
None
]
extra_deletion_matrix
=
batch
[
'extra_deletion_matrix'
]
deletion_matrix
=
batch
[
'deletion_matrix'
]
del_sum
=
torch
.
einsum
(
'nm, mc->nc'
,
cluster_assignment
,
extra_mask
*
extra_deletion_matrix
)
del_sum
+=
deletion_matrix
# Original sequence.
cluster_deletion_mean
=
del_sum
/
cluster_count
[:,
None
]
batch
[
'cluster_profile'
]
=
cluster_profile
batch
[
'cluster_deletion_mean'
]
=
cluster_deletion_mean
return
batch
def
create_target_feat
(
batch
):
"""Create the target features"""
batch
[
"target_feat"
]
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
"aatype"
],
21
).
to
(
torch
.
float32
)
return
batch
def
create_msa_feat
(
batch
):
"""Create and concatenate MSA features."""
device
=
batch
[
"msa"
]
msa_1hot
=
torch
.
nn
.
functional
.
one_hot
(
batch
[
'msa'
],
23
)
deletion_matrix
=
batch
[
'deletion_matrix'
]
has_deletion
=
torch
.
clamp
(
deletion_matrix
,
min
=
0.
,
max
=
1.
)[...,
None
]
pi
=
torch
.
acos
(
torch
.
zeros
(
1
,
device
=
deletion_matrix
.
device
))
*
2
deletion_value
=
(
torch
.
atan
(
deletion_matrix
/
3.
)
*
(
2.
/
pi
))[...,
None
]
deletion_mean_value
=
(
torch
.
atan
(
batch
[
'cluster_deletion_mean'
]
/
3.
)
*
(
2.
/
pi
)
)[...,
None
]
msa_feat
=
torch
.
cat
(
[
msa_1hot
,
has_deletion
,
deletion_value
,
batch
[
'cluster_profile'
],
deletion_mean_value
],
dim
=-
1
,
)
batch
[
"msa_feat"
]
=
msa_feat
return
batch
def
build_extra_msa_feat
(
batch
):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
num_extra_msa: Number of extra msa to use.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa
=
batch
[
'extra_msa'
]
deletion_matrix
=
batch
[
'extra_deletion_matrix'
]
msa_1hot
=
torch
.
nn
.
functional
.
one_hot
(
extra_msa
,
23
)
has_deletion
=
torch
.
clamp
(
deletion_matrix
,
min
=
0.
,
max
=
1.
)[...,
None
]
pi
=
torch
.
acos
(
torch
.
zeros
(
1
,
device
=
deletion_matrix
.
device
))
*
2
deletion_value
=
(
(
torch
.
atan
(
deletion_matrix
/
3.
)
*
(
2.
/
pi
))[...,
None
]
)
extra_msa_mask
=
batch
[
'extra_msa_mask'
]
catted
=
torch
.
cat
([
msa_1hot
,
has_deletion
,
deletion_value
],
dim
=-
1
)
return
catted
@
curry1
def
sample_msa
(
batch
,
max_seq
,
max_extra_msa_seq
,
seed
,
inf
=
1e6
):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
Protein with sampled msa.
"""
g
=
torch
.
Generator
(
device
=
batch
[
"msa"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
# Sample uniformly among sequences with at least one non-masked position.
logits
=
(
torch
.
clamp
(
torch
.
sum
(
batch
[
'msa_mask'
],
dim
=-
1
),
0.
,
1.
)
-
1.
)
*
inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if
'cluster_bias_mask'
not
in
batch
:
cluster_bias_mask
=
torch
.
nn
.
functional
.
pad
(
batch
[
'msa'
].
new_zeros
(
batch
[
'msa'
].
shape
[
0
]
-
1
),
(
1
,
0
),
value
=
1.
)
else
:
cluster_bias_mask
=
batch
[
'cluster_bias_mask'
]
logits
+=
cluster_bias_mask
*
inf
index_order
=
gumbel_argsort_sample_idx
(
logits
,
generator
=
g
)
sel_idx
=
index_order
[:
max_seq
]
extra_idx
=
index_order
[
max_seq
:][:
max_extra_msa_seq
]
for
k
in
[
'msa'
,
'deletion_matrix'
,
'msa_mask'
,
'bert_mask'
]:
if
k
in
batch
:
batch
[
'extra_'
+
k
]
=
batch
[
k
][
extra_idx
]
batch
[
k
]
=
batch
[
k
][
sel_idx
]
return
batch
def
make_msa_profile
(
batch
):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch
[
"msa_profile"
]
=
masked_mean
(
batch
[
'msa_mask'
][...,
None
],
torch
.
nn
.
functional
.
one_hot
(
batch
[
'msa'
],
22
),
dim
=-
3
,
)
return
batch
openfold/data/feature_pipeline.py
View file @
56d5e39c
...
...
@@ -20,7 +20,7 @@ import ml_collections
import
numpy
as
np
import
torch
from
openfold.data
import
input_pipeline
from
openfold.data
import
input_pipeline
,
input_pipeline_multimer
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
...
...
@@ -74,8 +74,10 @@ def np_example_to_features(
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
is_multimer
:
bool
=
False
):
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
...
...
@@ -88,11 +90,18 @@ def np_example_to_features(
np_example
=
np_example
,
features
=
feature_names
)
with
torch
.
no_grad
():
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
)
if
(
not
is_multimer
):
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
)
else
:
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
)
if
mode
==
"train"
:
p
=
torch
.
rand
(
1
).
item
()
...
...
@@ -122,10 +131,15 @@ class FeaturePipeline:
def
process_features
(
self
,
raw_features
:
FeatureDict
,
mode
:
str
=
"train"
,
mode
:
str
=
"train"
,
is_multimer
:
bool
=
False
,
)
->
FeatureDict
:
if
(
is_multimer
and
mode
!=
"predict"
):
raise
ValueError
(
"Multimer mode is not currently trainable"
)
return
np_example_to_features
(
np_example
=
raw_features
,
config
=
self
.
config
,
mode
=
mode
,
is_multimer
=
is_multimer
,
)
openfold/data/feature_processing_multimer.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data pipeline."""
from
typing
import
Iterable
,
MutableMapping
,
List
,
Mapping
from
openfold.data
import
msa_pairing
from
openfold.np
import
residue_constants
import
numpy
as
np
# TODO: Move this into the config
REQUIRED_FEATURES
=
frozenset
({
'aatype'
,
'all_atom_mask'
,
'all_atom_positions'
,
'all_chains_entity_ids'
,
'all_crops_all_chains_mask'
,
'all_crops_all_chains_positions'
,
'all_crops_all_chains_residue_ids'
,
'assembly_num_chains'
,
'asym_id'
,
'bert_mask'
,
'cluster_bias_mask'
,
'deletion_matrix'
,
'deletion_mean'
,
'entity_id'
,
'entity_mask'
,
'mem_peak'
,
'msa'
,
'msa_mask'
,
'num_alignments'
,
'num_templates'
,
'queue_size'
,
'residue_index'
,
'resolution'
,
'seq_length'
,
'seq_mask'
,
'sym_id'
,
'template_aatype'
,
'template_all_atom_mask'
,
'template_all_atom_positions'
})
MAX_TEMPLATES
=
4
MSA_CROP_SIZE
=
2048
def
_is_homomer_or_monomer
(
chains
:
Iterable
[
Mapping
[
str
,
np
.
ndarray
]])
->
bool
:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains
=
len
(
np
.
unique
(
np
.
concatenate
(
[
np
.
unique
(
chain
[
'entity_id'
][
chain
[
'entity_id'
]
>
0
])
for
chain
in
chains
])))
return
num_unique_chains
==
1
def
pair_and_merge
(
all_chain_features
:
MutableMapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]],
)
->
Mapping
[
str
,
np
.
ndarray
]:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features
(
all_chain_features
)
np_chains_list
=
list
(
all_chain_features
.
values
())
pair_msa_sequences
=
not
_is_homomer_or_monomer
(
np_chains_list
)
if
pair_msa_sequences
:
np_chains_list
=
msa_pairing
.
create_paired_features
(
chains
=
np_chains_list
)
np_chains_list
=
msa_pairing
.
deduplicate_unpaired_sequences
(
np_chains_list
)
np_chains_list
=
crop_chains
(
np_chains_list
,
msa_crop_size
=
MSA_CROP_SIZE
,
pair_msa_sequences
=
pair_msa_sequences
,
max_templates
=
MAX_TEMPLATES
)
np_example
=
msa_pairing
.
merge_chain_features
(
np_chains_list
=
np_chains_list
,
pair_msa_sequences
=
pair_msa_sequences
,
max_templates
=
MAX_TEMPLATES
)
np_example
=
process_final
(
np_example
)
return
np_example
def
crop_chains
(
chains_list
:
List
[
Mapping
[
str
,
np
.
ndarray
]],
msa_crop_size
:
int
,
pair_msa_sequences
:
bool
,
max_templates
:
int
)
->
List
[
Mapping
[
str
,
np
.
ndarray
]]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains
=
[]
for
chain
in
chains_list
:
cropped_chain
=
_crop_single_chain
(
chain
,
msa_crop_size
=
msa_crop_size
,
pair_msa_sequences
=
pair_msa_sequences
,
max_templates
=
max_templates
)
cropped_chains
.
append
(
cropped_chain
)
return
cropped_chains
def _crop_single_chain(chain: Mapping[str, np.ndarray],
                       msa_crop_size: int,
                       pair_msa_sequences: bool,
                       max_templates: int) -> Mapping[str, np.ndarray]:
    """Crops msa sequences to `msa_crop_size`.

    Mutates `chain` in place (MSA/template features are sliced along their
    first axis and the counters `num_alignments` / `num_templates` /
    `num_alignments_all_seq` are overwritten) and returns the same dict.
    """
    msa_size = chain['num_alignments']
    if pair_msa_sequences:
        # Paired mode: at most half of the budget goes to paired (all_seq) rows.
        msa_size_all_seq = chain['num_alignments_all_seq']
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)

        # We reduce the number of un-paired sequences, by the number of times a
        # sequence from this chain's MSA is included in the paired MSA. This keeps
        # the MSA size for each chain roughly constant.
        msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
        num_non_gapped_pairs = np.sum(
            np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
        num_non_gapped_pairs = np.minimum(
            num_non_gapped_pairs, msa_crop_size_all_seq)

        # Restrict the unpaired crop size so that paired+unpaired sequences do not
        # exceed msa_seqs_per_chain for each chain.
        max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
        msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
    else:
        msa_crop_size = np.minimum(msa_size, msa_crop_size)

    # Templates are only cropped when present and a positive budget was given.
    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)

    for k in chain:
        # Strip the '_all_seq' suffix so paired features share the feature-group
        # lookup with their unpaired counterparts.
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k and pair_msa_sequences:
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                chain[k] = chain[k][:msa_crop_size, :]

    # Record the post-crop sizes.
    chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            msa_crop_size_all_seq, dtype=np.int32)
    return chain
def process_final(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Final processing steps in data pipeline, after merging and pairing."""
    # Each step consumes and returns the full example dict.
    pipeline = (
        _correct_msa_restypes,
        _make_seq_mask,
        _make_msa_mask,
        _filter_features,
    )
    for step in pipeline:
        np_example = step(np_example)
    return np_example
def _correct_msa_restypes(np_example):
    """Correct MSA restype to have the same order as residue_constants."""
    # Remap every MSA entry from HHblits amino-acid ordering to ours.
    remap = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    remapped_msa = np.take(remap, np_example['msa'], axis=0)
    np_example['msa'] = remapped_msa.astype(np.int32)
    return np_example
def
_make_seq_mask
(
np_example
):
np_example
[
'seq_mask'
]
=
(
np_example
[
'entity_id'
]
>
0
).
astype
(
np
.
float32
)
return
np_example
def
_make_msa_mask
(
np_example
):
"""Mask features are all ones, but will later be zero-padded."""
np_example
[
'msa_mask'
]
=
np
.
ones_like
(
np_example
[
'msa'
],
dtype
=
np
.
float32
)
seq_mask
=
(
np_example
[
'entity_id'
]
>
0
).
astype
(
np
.
float32
)
np_example
[
'msa_mask'
]
*=
seq_mask
[
None
]
return
np_example
def _filter_features(
    np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested."""
    kept = {}
    for name, value in np_example.items():
        if name in REQUIRED_FEATURES:
            kept[name] = value
    return kept
def process_unmerged_features(
    all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]):
    """Postprocessing stage for per-chain features before merging.

    Mutates every per-chain feature dict in place: converts integer deletion
    matrices to float, adds `deletion_mean`, a dummy all-atom mask/positions
    pair derived from `aatype`, the total chain count, and an entity mask.
    """
    num_chains = len(all_chain_features)
    for features in all_chain_features.values():
        # Convert deletion matrices to float.
        features['deletion_matrix'] = np.asarray(
            features.pop('deletion_matrix_int'), dtype=np.float32)
        if 'deletion_matrix_int_all_seq' in features:
            features['deletion_matrix_all_seq'] = np.asarray(
                features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)

        features['deletion_mean'] = np.mean(
            features['deletion_matrix'], axis=0)

        # Add all_atom_mask and dummy all_atom_positions based on aatype.
        atom_mask = residue_constants.STANDARD_ATOM_MASK[features['aatype']]
        features['all_atom_mask'] = atom_mask
        features['all_atom_positions'] = np.zeros(
            list(atom_mask.shape) + [3])

        # Add assembly_num_chains.
        features['assembly_num_chains'] = np.asarray(num_chains)

    # Add entity_mask.
    for features in all_chain_features.values():
        features['entity_mask'] = (
            features['entity_id'] != 0).astype(np.int32)
openfold/data/input_pipeline_multimer.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
torch
from
openfold.data
import
(
data_transforms
,
data_transforms_multimer
,
)
def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
        data_transforms.make_atom14_masks,
    ]
    # Template pseudo-beta positions are only needed when templates are used.
    if common_cfg.use_templates:
        transforms.append(data_transforms.make_pseudo_beta("template_"))
    return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged.

    Args:
        common_cfg: Config shared across modes; consulted for MSA resampling
            behaviour and the optional `masked_msa` sub-config.
        mode_cfg: Mode-specific config with MSA cluster / extra-MSA budgets.
        ensemble_seed: Seed reused across recycling iterations when MSA
            resampling in recycling is disabled.

    Returns:
        A list of transform callables applied per ensemble iteration.
    """
    transforms = []

    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    msa_seed = None
    if not common_cfg.resample_msa_in_recycling:
        msa_seed = ensemble_seed

    transforms.append(
        data_transforms_multimer.sample_msa(
            max_msa_clusters,
            max_extra_msa,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        # BUGFIX: compare against None explicitly — a seed of 0 is valid,
        # but the previous truthiness test (`if msa_seed`) silently dropped it.
        transforms.append(
            data_transforms_multimer.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                seed=(msa_seed + 1) if msa_seed is not None else None,
            )
        )

    transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
    transforms.append(data_transforms_multimer.create_msa_feat)

    return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data.

    Args:
        tensors: Feature dict of torch tensors for a single example.
        common_cfg: Config shared across modes (recycling, templates, ...).
        mode_cfg: Config specific to the current mode.

    Returns:
        The transformed feature dict; every feature gains a trailing
        ensemble dimension of size ``num_recycling + 1``.
    """
    # One seed per example so every recycling iteration can reuse it when
    # MSA resampling in recycling is disabled.
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # NOTE: removed the dead `no_templates` computation that was assigned
    # from `template_aatype` but never read.

    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    # Honor a per-example recycling count when present; otherwise fall back
    # to the configured maximum.
    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors
@data_transforms.curry1
def compose(x, fs):
    """Threads `x` through each transform in `fs`, in order, and returns
    the final result. Curried so `compose(fns)` yields a single callable."""
    result = x
    for transform in fs:
        result = transform(result)
    return result
def map_fn(fun, x):
    """Applies `fun` to every element of `x` and stacks the per-element
    dict outputs feature-by-feature along a new trailing dimension.

    Assumes every call to `fun` returns a dict with identical keys.
    """
    per_elem = [fun(elem) for elem in x]
    return {
        key: torch.stack([d[key] for d in per_elem], dim=-1)
        for key in per_elem[0].keys()
    }
openfold/data/mmcif_parsing.py
View file @
56d5e39c
...
...
@@ -16,6 +16,7 @@
"""Parses the mmCIF file format."""
import
collections
import
dataclasses
import
functools
import
io
import
json
import
logging
...
...
@@ -173,6 +174,7 @@ def mmcif_loop_to_dict(
return
{
entry
[
index
]:
entry
for
entry
in
entries
}
@
functools
.
lru_cache
(
16
,
typed
=
False
)
def
parse
(
*
,
file_id
:
str
,
mmcif_string
:
str
,
catch_all_errors
:
bool
=
True
)
->
ParsingResult
:
...
...
@@ -346,7 +348,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
raw_resolution
=
parsed_info
[
res_key
][
0
]
header
[
"resolution"
]
=
float
(
raw_resolution
)
except
ValueError
:
logging
.
info
(
logging
.
debug
(
"Invalid resolution format: %s"
,
parsed_info
[
res_key
]
)
...
...
@@ -474,6 +476,20 @@ def get_atom_coords(
pos
[
residue_constants
.
atom_order
[
"SD"
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
"SD"
]]
=
1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd
=
residue_constants
.
atom_order
[
'CD'
]
nh1
=
residue_constants
.
atom_order
[
'NH1'
]
nh2
=
residue_constants
.
atom_order
[
'NH2'
]
if
(
res
.
get_resname
()
==
'ARG'
and
all
(
mask
[
atom_index
]
for
atom_index
in
(
cd
,
nh1
,
nh2
))
and
(
np
.
linalg
.
norm
(
pos
[
nh1
]
-
pos
[
cd
])
>
np
.
linalg
.
norm
(
pos
[
nh2
]
-
pos
[
cd
]))
):
pos
[
nh1
],
pos
[
nh2
]
=
pos
[
nh2
].
copy
(),
pos
[
nh1
].
copy
()
mask
[
nh1
],
mask
[
nh2
]
=
mask
[
nh2
].
copy
(),
mask
[
nh1
].
copy
()
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
...
...
openfold/data/msa_identifiers.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import
dataclasses
import
re
from
typing
import
Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
# The named group `SpeciesIdentifier` is the only capture consumed
# downstream (see `_parse_sequence_identifier`).
_UNIPROT_PATTERN = re.compile(
    r"""
    ^
    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
    (?:tr|sp)
    \|
    # A primary accession number of the UniProtKB entry.
    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
    (?:_\d)?
    \|
    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
    # protein ID code.
    (?:[A-Za-z0-9]+)
    _
    # A mnemonic species identification code.
    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
    # Small BFD uses a final value after an underscore, which we ignore.
    (?:_\d+)?
    $
    """,
    re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
    """Immutable identifiers parsed from an MSA sequence description."""
    # Mnemonic species code parsed from a UniProtKB entry name; '' when
    # no identifier could be parsed.
    species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Gets the species from an MSA sequence identifier.

    The identifier is expected to follow the UniProtKB
    `db|UniqueIdentifier|EntryName` format,
    e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`.

    Args:
        msa_sequence_identifier: a sequence identifier.

    Returns:
        An `Identifiers` instance whose species_id may be empty when the
        identifier did not match the expected format.
    """
    match = _UNIPROT_PATTERN.search(msa_sequence_identifier.strip())
    if match is None:
        return Identifiers()
    return Identifiers(species_id=match.group('SpeciesIdentifier'))
def
_extract_sequence_identifier
(
description
:
str
)
->
Optional
[
str
]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description
=
description
.
split
()
if
split_description
:
return
split_description
[
0
].
partition
(
'/'
)[
0
]
else
:
return
None
def get_identifiers(description: str) -> Identifiers:
    """Computes extra MSA features from the description."""
    identifier = _extract_sequence_identifier(description)
    # A blank description yields empty identifiers rather than an error.
    return (Identifiers() if identifier is None
            else _parse_sequence_identifier(identifier))
openfold/data/msa_pairing.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import
collections
import
functools
import
string
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Sequence
,
Mapping
import
numpy
as
np
import
pandas
as
pd
import
scipy.linalg
from
openfold.np
import
residue_constants
# TODO: This stuff should probably also be in a config
# Index of the gap character '-' in the MSA residue alphabet.
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

# Per-feature pad values used when padding / block-diagonalising MSA features.
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
                  'msa_mask_all_seq': 1,
                  'deletion_matrix_all_seq': 0,
                  'deletion_matrix_int_all_seq': 0,
                  'msa': MSA_GAP_IDX,
                  'msa_mask': 1,
                  'deletion_matrix': 0,
                  'deletion_matrix_int': 0}

# Feature groups: a feature's group decides how it is merged across chains
# (see _merge_features_from_multiple_chains).
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
                'all_atom_mask', 'seq_mask', 'between_segment_residues',
                'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
                'sym_id', 'entity_mask', 'deletion_mean',
                'prediction_atom_mask', 'literature_positions',
                'atom_indices_to_group_indices', 'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
                     'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
    chains: Iterable[Mapping[str, np.ndarray]],
) -> List[Mapping[str, np.ndarray]]:
    """Returns the original chains with paired NUM_SEQ features.

    Args:
        chains: A list of feature dictionaries for each chain.

    Returns:
        A list of feature dictionaries with sequence features including only
        rows to be paired.
    """
    chains = list(chains)
    chain_keys = chains[0].keys()

    # Pairing is meaningless for fewer than two chains.
    if len(chains) < 2:
        return chains

    paired_row_indices = reorder_paired_rows(pair_sequences(chains))

    updated_chains = []
    for chain_idx, chain in enumerate(chains):
        rows_for_chain = paired_row_indices[:, chain_idx]
        # Keep the non-paired features as-is, then re-add every paired
        # (_all_seq) feature restricted to the selected rows. Index -1
        # selects the padding row appended by pad_features.
        new_chain = {
            name: value for name, value in chain.items()
            if '_all_seq' not in name
        }
        for name in chain_keys:
            if name.endswith('_all_seq'):
                padded = pad_features(chain[name], name)
                new_chain[name] = padded[rows_for_chain]
        new_chain['num_alignments_all_seq'] = np.asarray(len(rows_for_chain))
        updated_chains.append(new_chain)
    return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
    """Add a 'padding' row at the end of the features list.

    The padding row will be selected as a 'paired' row in the case of partial
    alignment - for the chain that doesn't have paired alignment.

    Args:
        feature: The feature to be padded.
        feature_name: The name of the feature to be padded.

    Returns:
        The feature with an additional padding row, or the input unchanged
        for feature names that need no padding.
    """
    # BUGFIX: `np.string_` was removed in NumPy 2.0; `np.bytes_` is the same
    # 'S' dtype alias and exists in both NumPy 1.x and 2.x.
    assert feature.dtype != np.dtype(np.bytes_)
    if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
                        'deletion_matrix_all_seq',
                        'deletion_matrix_int_all_seq'):
        num_res = feature.shape[1]
        padding = MSA_PAD_VALUES[feature_name] * np.ones(
            [1, num_res], feature.dtype)
    elif feature_name == 'msa_species_identifiers_all_seq':
        padding = [b'']
    else:
        return feature
    feats_padded = np.concatenate([feature, padding], axis=0)
    return feats_padded
def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
    """Makes dataframe with msa features needed for msa pairing.

    Columns:
        msa_species_identifiers: species tag per MSA row.
        msa_row: original row index into `msa_all_seq`.
        msa_similarity: fraction of positions identical to the query (row 0).
        gap: fraction of gap characters in the row.
    """
    chain_msa = chain_features['msa_all_seq']
    query_seq = chain_msa[0]
    per_seq_similarity = np.sum(
        query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
    # CONSISTENCY FIX: use the module-level gap index instead of the magic
    # number 21 so the value stays in sync with
    # residue_constants.restypes_with_x_and_gap.
    per_seq_gap = np.sum(
        chain_msa == MSA_GAP_IDX, axis=-1) / float(len(query_seq))
    msa_df = pd.DataFrame({
        'msa_species_identifiers':
            chain_features['msa_species_identifiers_all_seq'],
        'msa_row': np.arange(
            len(chain_features['msa_species_identifiers_all_seq'])),
        'msa_similarity': per_seq_similarity,
        'gap': per_seq_gap
    })
    return msa_df
def
_create_species_dict
(
msa_df
:
pd
.
DataFrame
)
->
Dict
[
bytes
,
pd
.
DataFrame
]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup
=
{}
for
species
,
species_df
in
msa_df
.
groupby
(
'msa_species_identifiers'
):
species_lookup
[
species
]
=
species_df
return
species_lookup
def
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
:
List
[
pd
.
DataFrame
]
)
->
List
[
List
[
int
]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows
=
[]
num_seqs
=
[
len
(
species_df
)
for
species_df
in
this_species_msa_dfs
if
species_df
is
not
None
]
take_num_seqs
=
np
.
min
(
num_seqs
)
sort_by_similarity
=
(
lambda
x
:
x
.
sort_values
(
'msa_similarity'
,
axis
=
0
,
ascending
=
False
))
for
species_df
in
this_species_msa_dfs
:
if
species_df
is
not
None
:
species_df_sorted
=
sort_by_similarity
(
species_df
)
msa_rows
=
species_df_sorted
.
msa_row
.
iloc
[:
take_num_seqs
].
values
else
:
msa_rows
=
[
-
1
]
*
take_num_seqs
# take the last 'padding' row
all_paired_msa_rows
.
append
(
msa_rows
)
all_paired_msa_rows
=
list
(
np
.
array
(
all_paired_msa_rows
).
transpose
())
return
all_paired_msa_rows
def pair_sequences(
    examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains.

    Args:
        examples: Per-chain feature dicts carrying `msa_all_seq` and
            `msa_species_identifiers_all_seq`.

    Returns:
        Dict mapping "number of chains a species appears in" to an array of
        paired MSA row indices (one column per chain; -1 marks the padding
        row for chains lacking that species).
    """
    num_examples = len(examples)

    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    # Remove target sequence species. ROBUSTNESS FIX: use set.discard —
    # unlike list.remove it does not raise ValueError when no row had an
    # empty species identifier.
    common_species.discard(b'')
    common_species = sorted(common_species)

    # Row 0 of every chain (the query) is always paired with itself.
    all_paired_msa_rows = [np.zeros(len(examples), int)]
    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        # Skip species with very many sequences in any chain (threshold of
        # 600 kept from the original implementation — presumably to bound
        # pairing cost).
        if np.any(
            np.array([len(species_df) for species_df in this_species_msa_dfs
                      if isinstance(species_df, pd.DataFrame)]) > 600):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(
            this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)

    # NOTE: comprehension variables renamed — the original shadowed
    # `num_examples` here.
    all_paired_msa_rows_dict = {
        num_chains: np.array(paired_rows)
        for num_chains, paired_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict
def reorder_paired_rows(
    all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
    """Creates a list of indices of paired MSA rows across chains.

    Args:
        all_paired_msa_rows_dict: a mapping from the number of paired chains
            to the paired indices.

    Returns:
        An array of paired MSA row index tuples, ordered by (1) the number of
        chains in the paired alignment (all-chain pairings first) and
        (2) the magnitude of the row-index product within each group.
    """
    ordered_rows = []
    # Iterate pairing counts from most chains paired to fewest.
    for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
        rows = all_paired_msa_rows_dict[num_pairings]
        row_products = np.abs(np.array([np.prod(r) for r in rows]))
        for idx in np.argsort(row_products):
            ordered_rows.append(rows[idx])
    return np.array(ordered_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    """Like scipy.linalg.block_diag but with an optional padding value."""
    # Mark which cells of the block-diagonal layout are off-diagonal, then
    # fill those cells with pad_value instead of zero.
    filled_mask = scipy.linalg.block_diag(*[np.ones_like(a) for a in arrs])
    stacked = scipy.linalg.block_diag(*arrs)
    stacked += ((1.0 - filled_mask) * pad_value).astype(stacked.dtype)
    return stacked
def _correct_post_merged_feats(
    np_example: Mapping[str, np.ndarray],
    np_chains_list: Sequence[Mapping[str, np.ndarray]],
    pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Adds features that need to be computed/recomputed post merging.

    Recomputes `seq_length` / `num_alignments` for the merged example and
    builds `cluster_bias_mask` and `bert_mask`, whose layout depends on
    whether paired (all_seq) rows were concatenated in front of the
    block-diagonal unpaired rows.
    """
    num_res = np_example['aatype'].shape[0]
    # seq_length is stored per-residue (one copy of the total per position).
    np_example['seq_length'] = np.asarray(
        [num_res] * num_res, dtype=np.int32)
    np_example['num_alignments'] = np.asarray(
        np_example['msa'].shape[0], dtype=np.int32)

    if not pair_msa_sequences:
        # Generate a bias that is 1 for the first row of every block in the
        # block diagonal MSA - i.e. make sure the cluster stack always includes
        # the query sequences for each chain (since the first row is the query
        # sequence).
        cluster_bias_masks = []
        for chain in np_chains_list:
            mask = np.zeros(chain['msa'].shape[0])
            mask[0] = 1
            cluster_bias_masks.append(mask)
        np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
                     for x in np_chains_list]

        np_example['bert_mask'] = block_diag(
            *msa_masks, pad_value=0)
    else:
        # Paired mode: only the very first merged row (the query) gets the bias.
        np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
        np_example['cluster_bias_mask'][0] = 1

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
                     for x in np_chains_list]
        msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32)
                             for x in np_chains_list]

        msa_mask_block_diag = block_diag(
            *msa_masks, pad_value=0)
        msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
        # Paired rows come first in the merged MSA, so their mask is stacked
        # on top of the block-diagonal unpaired mask.
        np_example['bert_mask'] = np.concatenate(
            [msa_mask_all_seq, msa_mask_block_diag], axis=0)
    return np_example
def _pad_templates(chains: Sequence[Mapping[str, np.ndarray]],
                   max_templates: int) -> Sequence[Mapping[str, np.ndarray]]:
    """For each chain pad the number of templates to a fixed size.

    Args:
        chains: A list of protein chains.
        max_templates: Each chain will be padded to have this many templates.

    Returns:
        The list of chains, updated (in place) so that every template feature
        has max_templates entries along its first axis.
    """
    for chain in chains:
        for name, value in chain.items():
            if name not in TEMPLATE_FEATURES:
                continue
            # Zero-pad only the leading (template) dimension.
            pad_widths = [(0, 0)] * len(value.shape)
            pad_widths[0] = (0, max_templates - value.shape[0])
            chain[name] = np.pad(value, pad_widths, mode='constant')
    return chains
def _merge_features_from_multiple_chains(
    chains: Sequence[Mapping[str, np.ndarray]],
    pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
    """Merge features from multiple chains.

    Args:
        chains: A list of feature dictionaries that we want to merge.
        pair_msa_sequences: Whether to concatenate MSA features along the
            num_res dimension (if True), or to block diagonalize them
            (if False).

    Returns:
        A feature dictionary for the merged example.
    """
    merged_example = {}
    for feature_name in chains[0]:
        feats = [x[feature_name] for x in chains]
        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            if pair_msa_sequences or '_all_seq' in feature_name:
                merged_example[feature_name] = np.concatenate(feats, axis=1)
            else:
                merged_example[feature_name] = block_diag(
                    *feats, pad_value=MSA_PAD_VALUES[feature_name])
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            # BUGFIX: np.sum over a generator is deprecated (and rejected by
            # newer NumPy); the builtin sum keeps the original element-wise
            # semantics, and np.asarray preserves the .astype call for
            # plain-Python results.
            merged_example[feature_name] = np.asarray(
                sum(feats)).astype(np.int32)
        else:
            # Features outside the known groups are assumed identical across
            # chains; keep the first chain's copy.
            merged_example[feature_name] = feats[0]
    return merged_example
def _merge_homomers_dense_msa(
    chains: Iterable[Mapping[str, np.ndarray]]
) -> Sequence[Mapping[str, np.ndarray]]:
    """Merge all identical chains, making the resulting MSA dense.

    Args:
        chains: An iterable of features for each chain.

    Returns:
        A list of feature dictionaries. All features with the same entity_id
        will be merged - MSA features will be concatenated along the num_res
        dimension - making them dense.
    """
    # Group chains that share an entity id (i.e. identical sequences).
    by_entity = collections.defaultdict(list)
    for chain in chains:
        by_entity[chain['entity_id'][0]].append(chain)

    merged = []
    for entity_id in sorted(by_entity):
        merged.append(
            _merge_features_from_multiple_chains(
                by_entity[entity_id], pair_msa_sequences=True))
    return merged
def _concatenate_paired_and_unpaired_features(
    example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Merges paired and block-diagonalised features."""
    for name in MSA_FEATURES:
        if name not in example:
            continue
        # Paired (all_seq) rows come first, unpaired rows afterwards.
        example[name] = np.concatenate(
            [example[name + '_all_seq'], example[name]], axis=0)
    example['num_alignments'] = np.array(
        example['msa'].shape[0], dtype=np.int32)
    return example
def merge_chain_features(np_chains_list: List[Mapping[str, np.ndarray]],
                         pair_msa_sequences: bool,
                         max_templates: int) -> Mapping[str, np.ndarray]:
    """Merges features for multiple chains to single FeatureDict.

    Args:
        np_chains_list: List of FeatureDicts for each chain.
        pair_msa_sequences: Whether to merge paired MSAs.
        max_templates: The maximum number of templates to include.

    Returns:
        Single FeatureDict for entire complex.
    """
    padded_chains = _pad_templates(
        np_chains_list, max_templates=max_templates)
    dense_chains = _merge_homomers_dense_msa(padded_chains)
    # Unpaired MSA features will be always block-diagonalised; paired MSA
    # features will be concatenated.
    example = _merge_features_from_multiple_chains(
        dense_chains, pair_msa_sequences=False)
    if pair_msa_sequences:
        example = _concatenate_paired_and_unpaired_features(example)
    return _correct_post_merged_feats(
        np_example=example,
        np_chains_list=dense_chains,
        pair_msa_sequences=pair_msa_sequences)
def deduplicate_unpaired_sequences(
    np_chains: List[Mapping[str, np.ndarray]]
) -> List[Mapping[str, np.ndarray]]:
    """Removes unpaired sequences which duplicate a paired sequence."""
    chain_keys = np_chains[0].keys()

    for chain in np_chains:
        # Hash the paired MSA rows (as tuples) for O(1) membership tests.
        paired_sequences = {tuple(row) for row in chain['msa_all_seq']}
        # Keep only unpaired rows whose sequence is not already in the
        # paired MSA.
        keep_rows = [
            i for i, row in enumerate(chain['msa'])
            if tuple(row) not in paired_sequences
        ]
        for name in chain_keys:
            if name in MSA_FEATURES:
                chain[name] = chain[name][keep_rows]
        chain['num_alignments'] = np.array(
            chain['msa'].shape[0], dtype=np.int32)
    return np_chains
openfold/data/parsers.py
View file @
56d5e39c
...
...
@@ -16,14 +16,43 @@
"""Functions for parsing various file formats."""
import
collections
import
dataclasses
import
itertools
import
re
import
string
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Set
# A deletion matrix: one row per MSA sequence, one count per aligned column.
DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class Msa:
    """A parsed multiple sequence alignment.

    All three fields must have the same length (one entry per sequence);
    construction raises ValueError otherwise.
    """
    sequences: Sequence[str]
    deletion_matrix: DeletionMatrix
    descriptions: Optional[Sequence[str]]

    def __post_init__(self):
        n_seqs = len(self.sequences)
        if (len(self.deletion_matrix) != n_seqs
                or len(self.descriptions) != n_seqs):
            raise ValueError(
                "All fields for an MSA must have the same length"
            )

    def __len__(self):
        return len(self.sequences)

    def truncate(self, max_seqs: int):
        """Returns a new Msa keeping only the first `max_seqs` rows."""
        return Msa(
            sequences=self.sequences[:max_seqs],
            deletion_matrix=self.deletion_matrix[:max_seqs],
            descriptions=self.descriptions[:max_seqs],
        )
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
TemplateHit
:
"""Class representing a template hit."""
...
...
@@ -31,7 +60,7 @@ class TemplateHit:
index
:
int
name
:
str
aligned_cols
:
int
sum_probs
:
float
sum_probs
:
Optional
[
float
]
query
:
str
hit_sequence
:
str
indices_query
:
List
[
int
]
...
...
@@ -69,9 +98,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return
sequences
,
descriptions
def
parse_stockholm
(
stockholm_string
:
str
,
)
->
Tuple
[
Sequence
[
str
],
DeletionMatrix
,
Sequence
[
str
]]:
def
parse_stockholm
(
stockholm_string
:
str
)
->
Msa
:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
...
...
@@ -126,10 +153,14 @@ def parse_stockholm(
deletion_count
=
0
deletion_matrix
.
append
(
deletion_vec
)
return
msa
,
deletion_matrix
,
list
(
name_to_sequence
.
keys
())
return
Msa
(
sequences
=
msa
,
deletion_matrix
=
deletion_matrix
,
descriptions
=
list
(
name_to_sequence
.
keys
())
)
def
parse_a3m
(
a3m_string
:
str
)
->
Tuple
[
Sequence
[
str
],
DeletionMatrix
]
:
def
parse_a3m
(
a3m_string
:
str
)
->
Msa
:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
...
...
@@ -144,7 +175,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
sequences
,
_
=
parse_fasta
(
a3m_string
)
sequences
,
descriptions
=
parse_fasta
(
a3m_string
)
deletion_matrix
=
[]
for
msa_sequence
in
sequences
:
deletion_vec
=
[]
...
...
@@ -160,7 +191,11 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table
=
str
.
maketrans
(
""
,
""
,
string
.
ascii_lowercase
)
aligned_sequences
=
[
s
.
translate
(
deletion_table
)
for
s
in
sequences
]
return
aligned_sequences
,
deletion_matrix
return
Msa
(
sequences
=
aligned_sequences
,
deletion_matrix
=
deletion_matrix
,
descriptions
=
descriptions
)
def
_convert_sto_seq_to_a3m
(
...
...
@@ -174,7 +209,9 @@ def _convert_sto_seq_to_a3m(
def
convert_stockholm_to_a3m
(
stockholm_format
:
str
,
max_sequences
:
Optional
[
int
]
=
None
stockholm_format
:
str
,
max_sequences
:
Optional
[
int
]
=
None
,
remove_first_row_gaps
:
bool
=
True
,
)
->
str
:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions
=
{}
...
...
@@ -212,13 +249,19 @@ def convert_stockholm_to_a3m(
# Convert sto format to a3m line by line
a3m_sequences
=
{}
# query_sequence is assumed to be the first sequence
query_sequence
=
next
(
iter
(
sequences
.
values
()))
query_non_gaps
=
[
res
!=
"-"
for
res
in
query_sequence
]
if
(
remove_first_row_gaps
):
# query_sequence is assumed to be the first sequence
query_sequence
=
next
(
iter
(
sequences
.
values
()))
query_non_gaps
=
[
res
!=
"-"
for
res
in
query_sequence
]
for
seqname
,
sto_sequence
in
sequences
.
items
():
a3m_sequences
[
seqname
]
=
""
.
join
(
_convert_sto_seq_to_a3m
(
query_non_gaps
,
sto_sequence
)
)
# Dots are optional in a3m format and are commonly removed.
out_sequence
=
sto_sequence
.
replace
(
'.'
,
''
)
if
(
remove_first_row_gaps
):
out_sequence
=
''
.
join
(
_convert_sto_seq_to_a3m
(
query_non_gaps
,
out_sequence
)
)
a3m_sequences
[
seqname
]
=
out_sequence
fasta_chunks
=
(
f
">
{
k
}
{
descriptions
.
get
(
k
,
''
)
}
\n
{
a3m_sequences
[
k
]
}
"
...
...
@@ -227,6 +270,124 @@ def convert_stockholm_to_a3m(
return
"
\n
"
.
join
(
fasta_chunks
)
+
"
\n
"
# Include terminating newline.
def
_keep_line
(
line
:
str
,
seqnames
:
Set
[
str
])
->
bool
:
"""Function to decide which lines to keep."""
if
not
line
.
strip
():
return
True
if
line
.
strip
()
==
'//'
:
# End tag
return
True
if
line
.
startswith
(
'# STOCKHOLM'
):
# Start tag
return
True
if
line
.
startswith
(
'#=GC RF'
):
# Reference Annotation Line
return
True
if
line
[:
4
]
==
'#=GS'
:
# Description lines - keep if sequence in list.
_
,
seqname
,
_
=
line
.
split
(
maxsplit
=
2
)
return
seqname
in
seqnames
elif
line
.
startswith
(
'#'
):
# Other markup - filter out
return
False
else
:
# Alignment data - keep if sequence in list.
seqname
=
line
.
partition
(
' '
)[
0
]
return
seqname
in
seqnames
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
    """Reads + truncates a Stockholm file while preventing excessive RAM usage.

    Two streaming passes over the file: the first collects the names of the
    first `max_sequences` distinct sequences, the second keeps only the lines
    relevant to those sequences (via `_keep_line`), so the whole alignment is
    never held in memory at once.
    """
    seqnames = set()
    filtered_lines = []
    with open(stockholm_msa_path) as f:
        # Pass 1: gather sequence names until the cap is reached.
        for line in f:
            if not line.strip() or line.startswith(('#', '//')):
                # Ignore blank lines, markup and end symbols - remainder are
                # alignment sequence parts.
                continue
            seqnames.add(line.partition(' ')[0])
            if len(seqnames) >= max_sequences:
                break

        # Pass 2: rewind and keep only the lines for the selected sequences.
        f.seek(0)
        filtered_lines = [line for line in f if _keep_line(line, seqnames)]

    return ''.join(filtered_lines)
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
    """Removes empty columns (dashes-only) from a Stockholm MSA.

    The alignment may be wrapped into several interleaved chunks; each chunk
    is terminated by its '#=GC RF' reference-annotation line. A per-chunk
    column mask is built (True where at least one row has a non-dash), then
    applied to every row of the chunk, including the RF line itself.

    NOTE(review): assumes every alignment row in a chunk is at least as long
    as that chunk's RF annotation, and that each chunk ends with exactly one
    '#=GC RF' line — verify against the producing tool's output.
    """
    # Maps original line index -> processed (column-masked) line text;
    # reassembled in index order at the end.
    processed_lines = {}
    # Alignment rows of the current chunk, keyed by original line index,
    # awaiting their chunk's RF line before they can be masked.
    unprocessed_lines = {}
    for i, line in enumerate(stockholm_msa.splitlines()):
        if line.startswith('#=GC RF'):
            reference_annotation_i = i
            reference_annotation_line = line
            # Reached the end of this chunk of the alignment. Process chunk.
            _, _, first_alignment = line.rpartition(' ')
            mask = []
            for j in range(len(first_alignment)):
                for _, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    if alignment[j] != '-':
                        mask.append(True)
                        break
                else:  # Every row contained a hyphen - empty column.
                    mask.append(False)
            # Add reference annotation for processing with mask.
            unprocessed_lines[reference_annotation_i] = reference_annotation_line

            if not any(mask):
                # All columns were empty. Output empty lines for chunk.
                for line_index in unprocessed_lines:
                    processed_lines[line_index] = ''
            else:
                for line_index, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    # Keep only the columns where mask is True.
                    masked_alignment = ''.join(itertools.compress(alignment, mask))
                    processed_lines[line_index] = f'{prefix} {masked_alignment}'

            # Clear raw_alignments.
            unprocessed_lines = {}
        elif line.strip() and not line.startswith(('#', '//')):
            # Alignment row: defer until the chunk's RF line is seen.
            unprocessed_lines[i] = line
        else:
            # Markup, blank lines and the '//' terminator pass through as-is.
            processed_lines[i] = line
    # Indices are dense (every input line lands in processed_lines), so
    # iterating range(len(...)) restores the original line order.
    return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
    """Remove duplicate sequences (ignoring insertions wrt query).

    Rows are compared after masking out the columns where the query (the
    first alignment in the file) has a gap, so two hits that differ only in
    insertions relative to the query count as duplicates. The first row with
    a given masked sequence wins; the output keeps the original lines (via
    `_keep_line`) for the surviving sequence names.

    NOTE(review): assumes the MSA contains at least one alignment row —
    `next(iter(...))` raises StopIteration on an empty alignment.
    """
    # Concatenate wrapped alignment chunks per sequence name.
    sequence_dict = collections.defaultdict(str)

    # First we must extract all sequences from the MSA.
    for line in stockholm_msa.splitlines():
        # Only consider the alignments - ignore reference annotation, empty lines,
        # descriptions or markup.
        if line.strip() and not line.startswith(('#', '//')):
            line = line.strip()
            seqname, alignment = line.split()
            sequence_dict[seqname] += alignment

    seen_sequences = set()
    seqnames = set()
    # First alignment is the query (dicts preserve insertion order).
    query_align = next(iter(sequence_dict.values()))
    mask = [c != '-' for c in query_align]  # Mask is False for insertions.
    for seqname, alignment in sequence_dict.items():
        # Apply mask to remove all insertions from the string.
        masked_alignment = ''.join(itertools.compress(alignment, mask))
        if masked_alignment in seen_sequences:
            continue
        else:
            seen_sequences.add(masked_alignment)
            seqnames.add(seqname)

    # Re-emit the original file, keeping only lines for surviving sequences.
    filtered_lines = []
    for line in stockholm_msa.splitlines():
        if _keep_line(line, seqnames):
            filtered_lines.append(line)

    return '\n'.join(filtered_lines) + '\n'
def
_get_hhr_line_regex_groups
(
regex_pattern
:
str
,
line
:
str
)
->
Sequence
[
Optional
[
str
]]:
...
...
@@ -280,7 +441,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"Could not parse section: %s. Expected this:
\n
%s to contain summary."
%
(
detailed_lines
,
detailed_lines
[
2
])
)
(
prob_true
,
e_value
,
_
,
aligned_cols
,
_
,
_
,
sum_probs
,
neff
)
=
[
(
_
,
_
,
_
,
aligned_cols
,
_
,
_
,
sum_probs
,
_
)
=
[
float
(
x
)
for
x
in
match
.
groups
()
]
...
...
@@ -388,3 +549,115 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
target_name
=
fields
[
0
]
e_values
[
target_name
]
=
float
(
e_value
)
return
e_values
def
_get_indices
(
sequence
:
str
,
start
:
int
)
->
List
[
int
]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices
=
[]
counter
=
start
for
symbol
in
sequence
:
# Skip gaps but add a placeholder so that the alignment is preserved.
if
symbol
==
'-'
:
indices
.
append
(
-
1
)
# Skip deleted residues, but increase the counter.
elif
symbol
.
islower
():
counter
+=
1
# Normal aligned residue. Increase the counter and append to indices.
else
:
indices
.
append
(
counter
)
counter
+=
1
return
indices
@dataclasses.dataclass(frozen=True)
class HitMetadata:
    """Metadata parsed from an hmmsearch hit description line.

    Populated by `_parse_hmmsearch_description` from lines such as
    '>4pqx_A/2-217 [subseq from] mol:protein length:217 Free text'.
    """
    pdb_id: str  # PDB code of the hit, e.g. '4pqx'.
    chain: str  # Chain identifier, e.g. 'A'.
    start: int  # Start residue of the hit subsequence (from 'start-end').
    end: int  # End residue of the hit subsequence (from 'start-end').
    length: int  # Chain length from the 'length:' field.
    text: str  # Free-text remainder of the description (may be empty).
def _parse_hmmsearch_description(description: str) -> HitMetadata:
    """Parses the hmmsearch A3M sequence description line.

    Raises:
      ValueError: If the description does not match the expected format.
    """
    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
    m = re.match(
        r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
        description.strip())
    if m is None:
        raise ValueError(f'Could not parse description: "{description}".')

    pdb_id, chain, start, end, length, text = m.groups()
    return HitMetadata(
        pdb_id=pdb_id,
        chain=chain,
        start=int(start),
        end=int(end),
        length=int(length),
        text=text,
    )
def parse_hmmsearch_a3m(
    query_sequence: str,
    a3m_string: str,
    skip_first: bool = True
) -> Sequence[TemplateHit]:
    """Parses an a3m string produced by hmmsearch.

    Args:
      query_sequence: The query sequence.
      a3m_string: The a3m string produced by hmmsearch.
      skip_first: Whether to skip the first sequence in the a3m string.

    Returns:
      A sequence of `TemplateHit` results.
    """
    # Zip the descriptions and MSAs together, skip the first query sequence.
    parsed_a3m = list(zip(*parse_fasta(a3m_string)))
    if skip_first:
        parsed_a3m = parsed_a3m[1:]

    indices_query = _get_indices(query_sequence, start=0)

    hits = []
    for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
        if 'mol:protein' not in hit_description:
            continue  # Skip non-protein chains.
        metadata = _parse_hmmsearch_description(hit_description)
        # Aligned columns are only the match states (uppercase, non-gap).
        # Generator instead of a list comprehension: no throwaway list.
        aligned_cols = sum(r.isupper() and r != '-' for r in hit_sequence)
        # metadata.start is 1-based in the description; _get_indices counts
        # from the given start, hence the -1.
        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)

        hit = TemplateHit(
            index=i,
            name=f'{metadata.pdb_id}_{metadata.chain}',
            aligned_cols=aligned_cols,
            sum_probs=None,  # hmmsearch output carries no sum_probs score.
            query=query_sequence,
            hit_sequence=hit_sequence.upper(),
            indices_query=indices_query,
            indices_hit=indices_hit,
        )
        hits.append(hit)

    return hits
def parse_hmmsearch_sto(
    output_string: str,
    input_sequence: str
) -> Sequence[TemplateHit]:
    """Gets parsed template hits from the raw string output by the tool.

    Converts the Stockholm alignment to A3M first (keeping the first row's
    gap columns, so hit indices stay aligned to the query), then parses all
    rows — including the first — into `TemplateHit` records.
    """
    a3m_string = convert_stockholm_to_a3m(
        output_string,
        remove_first_row_gaps=False,
    )
    return parse_hmmsearch_a3m(
        query_sequence=input_sequence,
        a3m_string=a3m_string,
        skip_first=False,
    )
openfold/data/templates.py
View file @
56d5e39c
...
...
@@ -14,8 +14,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import
abc
import
dataclasses
import
datetime
import
functools
import
glob
import
json
import
logging
...
...
@@ -65,10 +67,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class
PdbIdError
(
PrefilterError
):
"""An error indicating that the hit PDB ID was identical to the query."""
class
AlignRatioError
(
PrefilterError
):
"""An error indicating that the hit align ratio to the query was too small."""
...
...
@@ -204,7 +202,6 @@ def _assess_hhsearch_hit(
hit
:
parsers
.
TemplateHit
,
hit_pdb_code
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
datetime
.
datetime
,
max_subsequence_ratio
:
float
=
0.95
,
...
...
@@ -218,7 +215,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
...
...
@@ -230,7 +226,6 @@ def _assess_hhsearch_hit(
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
...
...
@@ -241,13 +236,6 @@ def _assess_hhsearch_hit(
template_sequence
=
hit
.
hit_sequence
.
replace
(
"-"
,
""
)
length_ratio
=
float
(
len
(
template_sequence
))
/
len
(
query_sequence
)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate
=
(
template_sequence
in
query_sequence
and
length_ratio
>
max_subsequence_ratio
)
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
date
=
release_dates
[
hit_pdb_code
.
upper
()]
raise
DateError
(
...
...
@@ -255,16 +243,19 @@ def _assess_hhsearch_hit(
f
"(
{
release_date_cutoff
}
)."
)
if
query_pdb_code
is
not
None
:
if
query_pdb_code
.
lower
()
==
hit_pdb_code
.
lower
():
raise
PdbIdError
(
"PDB code identical to Query PDB code."
)
if
align_ratio
<=
min_align_ratio
:
raise
AlignRatioError
(
"Proportion of residues aligned to query too small. "
f
"Align ratio:
{
align_ratio
}
."
)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate
=
(
template_sequence
in
query_sequence
and
length_ratio
>
max_subsequence_ratio
)
if
duplicate
:
raise
DuplicateError
(
"Template is an exact subsequence of query with large "
...
...
@@ -424,9 +415,10 @@ def _realign_pdb_template_to_query(
)
try
:
(
old_aligned_template
,
new_aligned_template
),
_
=
parsers
.
parse_a3m
(
parsed_a3m
=
parsers
.
parse_a3m
(
aligner
.
align
([
old_template_sequence
,
new_template_sequence
])
)
old_aligned_template
,
new_aligned_template
=
parsed_a3m
.
sequences
except
Exception
as
e
:
raise
QueryToTemplateAlignError
(
"Could not align old template %s to template %s (%s_%s). Error: %s"
...
...
@@ -768,7 +760,6 @@ class SingleHitResult:
def
_prefilter_hit
(
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
max_template_date
:
datetime
.
datetime
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
...
...
@@ -789,17 +780,14 @@ def _prefilter_hit(
hit
=
hit
,
hit_pdb_code
=
hit_pdb_code
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
release_dates
=
release_dates
,
release_date_cutoff
=
max_template_date
,
)
except
PrefilterError
as
e
:
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
):
logging
.
info
(
msg
)
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
DuplicateError
)):
# In strict mode we treat some prefilter cases as errors.
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
...
...
@@ -808,9 +796,16 @@ def _prefilter_hit(
return
PrefilterResult
(
valid
=
True
,
error
=
None
,
warning
=
None
)
@
functools
.
lru_cache
(
16
,
typed
=
False
)
def
_read_file
(
path
):
with
open
(
path
,
'r'
)
as
f
:
file_data
=
f
.
read
()
return
file_data
def
_process_single_hit
(
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
mmcif_dir
:
str
,
max_template_date
:
datetime
.
datetime
,
...
...
@@ -847,9 +842,9 @@ def _process_single_hit(
query_sequence
,
template_sequence
,
)
# Fail if we can't find the mmCIF file.
with
open
(
cif_path
,
"r"
)
as
cif_file
:
cif_string
=
cif_file
.
read
()
cif_string
=
_read_file
(
cif_path
)
parsing_result
=
mmcif_parsing
.
parse
(
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
...
...
@@ -882,7 +877,11 @@ def _process_single_hit(
kalign_binary_path
=
kalign_binary_path
,
_zero_center_positions
=
_zero_center_positions
,
)
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
if
hit
.
sum_probs
is
None
:
features
[
"template_sum_probs"
]
=
[
0
]
else
:
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
...
...
@@ -903,7 +902,7 @@ def _process_single_hit(
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
sum_probs
if
hit
.
sum_probs
else
0.
,
hit
.
index
,
str
(
e
),
parsing_result
.
errors
,
...
...
@@ -920,7 +919,7 @@ def _process_single_hit(
%
(
hit_pdb_code
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
sum_probs
if
hit
.
sum_probs
else
0.
,
hit
.
index
,
str
(
e
),
parsing_result
.
errors
,
...
...
@@ -986,8 +985,8 @@ class TemplateSearchResult:
warnings
:
Sequence
[
str
]
class
TemplateHitFeaturizer
:
"""A class for turning
hhr hits to
template features."""
class
TemplateHitFeaturizer
(
abc
.
ABC
)
:
"""A
n abstract base
class for turning template
hits to
features."""
def
__init__
(
self
,
mmcif_dir
:
str
,
...
...
@@ -1036,7 +1035,7 @@ class TemplateHitFeaturizer:
raise
ValueError
(
"max_template_date must be set and have format YYYY-MM-DD."
)
self
.
max_hits
=
max_hits
self
.
_
max_hits
=
max_hits
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_strict_error_check
=
strict_error_check
...
...
@@ -1059,31 +1058,29 @@ class TemplateHitFeaturizer:
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_zero_center_positions
=
_zero_center_positions
@
abc
.
abstractmethod
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
""" Computes the templates for a given query sequence """
class
HhsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
query_release_date
:
Optional
[
datetime
.
datetime
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
)
->
TemplateSearchResult
:
"""Computes the templates for given query sequence (more details above)."""
logging
.
info
(
"Searching for template for: %s"
,
query_
pdb_cod
e
)
logging
.
info
(
"Searching for template for: %s"
,
query_
sequenc
e
)
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
# Always use a max_template_date. Set to query_release_date minus 60 days
# if that's earlier.
template_cutoff_date
=
self
.
_max_template_date
if
query_release_date
:
delta
=
datetime
.
timedelta
(
days
=
60
)
if
query_release_date
-
delta
<
template_cutoff_date
:
template_cutoff_date
=
query_release_date
-
delta
assert
template_cutoff_date
<
query_release_date
assert
template_cutoff_date
<=
self
.
_max_template_date
num_hits
=
0
already_seen
=
set
()
errors
=
[]
warnings
=
[]
...
...
@@ -1091,9 +1088,8 @@ class TemplateHitFeaturizer:
for
hit
in
hits
:
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
max_template_date
=
template_
cutoff_
date
,
max_template_date
=
self
.
_max_
template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
...
...
@@ -1119,17 +1115,16 @@ class TemplateHitFeaturizer:
for
i
in
idx
:
# We got all the templates we wanted, stop processing hits.
if
num_hits
>=
self
.
max_hits
:
if
len
(
already_seen
)
>=
self
.
_
max_hits
:
break
hit
=
filtered
[
i
]
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
template_
cutoff_
date
,
max_template_date
=
self
.
_max_
template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
...
...
@@ -1153,22 +1148,152 @@ class TemplateHitFeaturizer:
result
.
warning
,
)
else
:
# Increment the hit counter, since we got features out of this hit.
num_hits
+=
1
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
(
already_seen_key
in
already_seen
):
continue
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
for
name
in
template_features
:
if
num_hits
>
0
:
if
already_seen
:
for
name
in
template_features
:
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
else
:
# Make sure the feature has correct dtype even if empty.
template_features
[
name
]
=
np
.
array
(
[],
dtype
=
TEMPLATE_FEATURES
[
name
]
)
else
:
num_res
=
len
(
query_sequence
)
# Construct a default template with all zeros.
template_features
=
{
"template_aatype"
:
np
.
zeros
(
(
1
,
num_res
,
len
(
residue_constants
.
restypes_with_x_and_gap
)),
np
.
float32
),
"template_all_atom_masks"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
),
np
.
float32
),
"template_all_atom_positions"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
,
3
),
np
.
float32
),
"template_domain_names"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sequence"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sum_probs"
:
np
.
array
([
0
],
dtype
=
np
.
float32
),
}
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
)
class
HmmsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
logging
.
info
(
"Searching for template for: %s"
,
query_sequence
)
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
already_seen
=
set
()
errors
=
[]
warnings
=
[]
# DISCREPANCY: This filtering scheme that saves time
filtered
=
[]
for
hit
in
hits
:
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
hit
=
hit
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
)
if
prefilter_result
.
error
:
errors
.
append
(
prefilter_result
.
error
)
if
prefilter_result
.
warning
:
warnings
.
append
(
prefilter_result
.
warning
)
if
prefilter_result
.
valid
:
filtered
.
append
(
hit
)
filtered
=
list
(
sorted
(
filtered
,
key
=
lambda
x
:
x
.
sum_probs
if
x
.
sum_probs
else
0.
,
reverse
=
True
)
)
idx
=
list
(
range
(
len
(
filtered
)))
if
(
self
.
_shuffle_top_k_prefiltered
):
stk
=
self
.
_shuffle_top_k_prefiltered
idx
[:
stk
]
=
np
.
random
.
permutation
(
idx
[:
stk
])
for
i
in
idx
:
if
(
len
(
already_seen
)
>=
self
.
_max_hits
):
break
hit
=
filtered
[
i
]
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
)
if
result
.
error
:
errors
.
append
(
result
.
error
)
if
result
.
warning
:
warnings
.
append
(
result
.
warning
)
if
result
.
features
is
None
:
logging
.
debug
(
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
.
name
,
result
.
error
,
result
.
warning
,
)
else
:
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
(
already_seen_key
in
already_seen
):
continue
# Increment the hit counter, since we got features out of this hit.
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
if
already_seen
:
for
name
in
template_features
:
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
else
:
num_res
=
len
(
query_sequence
)
# Construct a default template with all zeros.
template_features
=
{
"template_aatype"
:
np
.
zeros
(
(
1
,
num_res
,
len
(
residue_constants
.
restypes_with_x_and_gap
)),
np
.
float32
),
"template_all_atom_masks"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
),
np
.
float32
),
"template_all_atom_positions"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
,
3
),
np
.
float32
),
"template_domain_names"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sequence"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sum_probs"
:
np
.
array
([
0
],
dtype
=
np
.
float32
),
}
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
,
)
openfold/data/tools/hhblits.py
View file @
56d5e39c
...
...
@@ -18,7 +18,7 @@ import glob
import
logging
import
os
import
subprocess
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
from
openfold.data.tools
import
utils
...
...
@@ -99,9 +99,9 @@ class HHBlits:
self
.
p
=
p
self
.
z
=
z
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
def
query
(
self
,
input_fasta_path
:
str
)
->
List
[
Mapping
[
str
,
Any
]
]
:
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
db_cmd
=
[]
...
...
@@ -172,4 +172,4 @@ class HHBlits:
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
,
)
return
raw_output
return
[
raw_output
]
openfold/data/tools/hhsearch.py
View file @
56d5e39c
...
...
@@ -18,8 +18,9 @@ import glob
import
logging
import
os
import
subprocess
from
typing
import
Sequence
from
typing
import
Sequence
,
Optional
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -62,11 +63,20 @@ class HHSearch:
f
"Could not find HHsearch database
{
database_path
}
"
)
def
query
(
self
,
a3m
:
str
)
->
str
:
@
property
def
output_format
(
self
)
->
str
:
return
'hhr'
@
property
def
input_format
(
self
)
->
str
:
return
'a3m'
def
query
(
self
,
a3m
:
str
,
output_dir
:
Optional
[
str
]
=
None
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
output_dir
=
query_tmp_dir
if
output_dir
is
None
else
output_dir
hhr_path
=
os
.
path
.
join
(
output_dir
,
"hhsearch_output.hhr"
)
with
open
(
input_path
,
"w"
)
as
f
:
f
.
write
(
a3m
)
...
...
@@ -104,3 +114,12 @@ class HHSearch:
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
return
hhr
@
staticmethod
def
get_template_hits
(
output_string
:
str
,
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
"""Gets parsed template hits from the raw string output by the tool"""
del
input_sequence
# Used by hmmsearch but not needed for hhsearch
return
parsers
.
parse_hhr
(
output_string
)
openfold/data/tools/hmmbuild.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import
os
import
re
import
subprocess
from
absl
import
logging
from
openfold.data.tools
import
utils
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self, *, binary_path: str, singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
          binary_path: The path to the hmmbuild executable.
          singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
            just use a common substitution score matrix.

        Raises:
          RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds an HMM profile for aligned sequences given as a Stockholm string.

        Args:
          sto: A string with the aligned sequences in the Stockholm format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds an HMM profile for the aligned sequences given as an A3M string.

        Args:
          a3m: A string with the aligned sequences in the A3M format.

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds an HMM for the aligned sequences given as an MSA string.

        Args:
          msa: A string with the aligned sequences, in A3M or STO format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
          ValueError: If model_construction is not 'hand' or 'fast'.
        """
        if model_construction not in {'hand', 'fast'}:
            # Fixed: the implicit string concatenation below was missing a
            # space, producing "...onlyhand and fast supported.".
            raise ValueError(
                f'Invalid model_construction {model_construction} - only '
                'hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                                   % (stdout.decode('utf-8'),
                                      stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
openfold/data/tools/hmmsearch.py
0 → 100644
View file @
56d5e39c
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import
os
import
subprocess
from
typing
import
Optional
,
Sequence
from
absl
import
logging
from
openfold.data
import
parsers
from
openfold.data.tools
import
hmmbuild
from
openfold.data.tools
import
utils
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 hmmbuild_binary_path: str,
                 database_path: str,
                 flags: Optional[Sequence[str]] = None):
        """Initializes the Python hmmsearch wrapper.

        Args:
          binary_path: The path to the hmmsearch executable.
          hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
            an hmm from an input a3m.
          database_path: The path to the hmmsearch database (FASTA format).
          flags: List of flags to be used by hmmsearch.

        Raises:
          RuntimeError: If hmmsearch binary not found within the path.
          ValueError: If the database file does not exist.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        # Raw tool output is a Stockholm alignment.
        return 'sto'

    @property
    def input_format(self) -> str:
        # Queries are built from a Stockholm MSA.
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa.

        Builds an HMM profile from the MSA first ('hand' construction uses the
        MSA's reference annotation to pick consensus columns), then searches.
        """
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto,
            model_construction='hand')
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self,
                       hmm: str,
                       output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given hmm.

        Args:
          hmm: The profile HMM text to search with.
          output_dir: Where to write the Stockholm output. Defaults to a
            temporary directory that is deleted when the search finishes.

        Returns:
          The Stockholm-format alignment produced by hmmsearch.

        Raises:
          RuntimeError: If the hmmsearch subprocess exits with a nonzero code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,  # -A writes the alignment to out_path.
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError('hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
                                   % (stdout.decode('utf-8'),
                                      stderr.decode('utf-8')))

            # Read the result back before the temporary directory (which may
            # be output_dir) is cleaned up.
            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(output_string: str,
                          input_sequence: str) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
openfold/data/tools/jackhmmer.py
View file @
56d5e39c
...
...
@@ -23,6 +23,7 @@ import subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -93,10 +94,13 @@ class Jackhmmer:
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
self
,
input_fasta_path
:
str
,
database_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.sto"
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
...
...
@@ -167,8 +171,11 @@ class Jackhmmer:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
if
(
max_sequences
is
None
):
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
else
:
sto
=
parsers
.
truncate_stockholm_msa
(
sto_path
,
max_sequences
)
raw_output
=
dict
(
sto
=
sto
,
...
...
@@ -180,10 +187,16 @@ class Jackhmmer:
return
raw_output
def
query
(
self
,
input_fasta_path
:
str
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
def
query
(
self
,
input_fasta_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
"""Queries the database using Jackhmmer."""
if
self
.
num_streamed_chunks
is
None
:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
single_chunk_result
=
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
,
max_sequences
,
)
return
[
single_chunk_result
]
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
...
...
@@ -217,12 +230,20 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future
.
result
()
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
))
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
),
max_sequences
)
)
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
future
=
next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if
(
i
<
self
.
num_streamed_chunks
):
future
=
next_future
if
self
.
streaming_callback
:
self
.
streaming_callback
(
i
)
return
chunked_output
openfold/data/tools/kalign.py
View file @
56d5e39c
...
...
@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
)
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"input.fasta"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment