Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
57f869d6
Commit
57f869d6
authored
Mar 09, 2022
by
Gustaf Ahdritz
Browse files
Continue work on AlphaFold-Multimer
parent
100485dd
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
986 additions
and
96 deletions
+986
-96
openfold/config.py
openfold/config.py
+10
-0
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+280
-4
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+14
-0
openfold/data/templates.py
openfold/data/templates.py
+115
-25
openfold/data/tools/hhblits.py
openfold/data/tools/hhblits.py
+4
-4
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+18
-1
openfold/data/tools/hmmbuild.py
openfold/data/tools/hmmbuild.py
+137
-0
openfold/data/tools/hmmsearch.py
openfold/data/tools/hmmsearch.py
+134
-0
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+28
-7
openfold/data/tools/kalign.py
openfold/data/tools/kalign.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+69
-52
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+176
-2
No files found.
openfold/config.py
View file @
57f869d6
...
@@ -74,6 +74,8 @@ def model_config(name, train=False, low_prec=False):
...
@@ -74,6 +74,8 @@ def model_config(name, train=False, low_prec=False):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
c
.
model
.
update
(
multimer_model_config_update
)
else
:
else
:
raise
ValueError
(
"Invalid model name"
)
raise
ValueError
(
"Invalid model name"
)
...
@@ -493,3 +495,11 @@ config = mlc.ConfigDict(
...
@@ -493,3 +495,11 @@ config = mlc.ConfigDict(
"ema"
:
{
"decay"
:
0.999
},
"ema"
:
{
"decay"
:
0.999
},
}
}
)
)
multimer_model_config_update
=
mlc
.
ConfigDict
(
"relative_encoding"
:
{
"enabled"
:
True
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
}
)
openfold/data/data_pipeline.py
View file @
57f869d6
...
@@ -25,6 +25,7 @@ from openfold.data import (
...
@@ -25,6 +25,7 @@ from openfold.data import (
parsers
,
parsers
,
mmcif_parsing
,
mmcif_parsing
,
msa_identifiers
,
msa_identifiers
,
msa_pairing
,
)
)
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.data.tools.utils
import
to_date
...
@@ -277,11 +278,13 @@ class AlignmentRunner:
...
@@ -277,11 +278,13 @@ class AlignmentRunner:
mgnify_database_path
:
Optional
[
str
]
=
None
,
mgnify_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
template_searcher
:
Optional
[
TemplateSearcher
]
=
None
,
template_searcher
:
Optional
[
TemplateSearcher
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
uniref_max_hits
:
int
=
10000
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
mgnify_max_hits
:
int
=
5000
,
uniprot_max_hits
:
int
=
50000
,
):
):
"""
"""
Args:
Args:
...
@@ -320,6 +323,7 @@ class AlignmentRunner:
...
@@ -320,6 +323,7 @@ class AlignmentRunner:
uniref90_database_path
,
uniref90_database_path
,
mgnify_database_path
,
mgnify_database_path
,
bfd_database_path
if
use_small_bfd
else
None
,
bfd_database_path
if
use_small_bfd
else
None
,
uniprot_database_path
,
],
],
},
},
"hhblits"
:
{
"hhblits"
:
{
...
@@ -339,6 +343,7 @@ class AlignmentRunner:
...
@@ -339,6 +343,7 @@ class AlignmentRunner:
self
.
uniref_max_hits
=
uniref_max_hits
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
uniprot_max_hits
=
uniprot_max_hits
self
.
use_small_bfd
=
use_small_bfd
self
.
use_small_bfd
=
use_small_bfd
if
(
no_cpus
is
None
):
if
(
no_cpus
is
None
):
...
@@ -381,6 +386,13 @@ class AlignmentRunner:
...
@@ -381,6 +386,13 @@ class AlignmentRunner:
n_cpu
=
no_cpus
,
n_cpu
=
no_cpus
,
)
)
self
.
_uniprot_msa_runner
=
None
if
(
uniprot_database_path
is
not
None
):
self
.
jackhmmer_uniprot_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
uniprot_database_path
)
if
(
template_searcher
is
not
None
and
if
(
template_searcher
is
not
None
and
self
.
jackhmmer_uniref90_runner
is
None
self
.
jackhmmer_uniref90_runner
is
None
):
):
...
@@ -456,6 +468,148 @@ class AlignmentRunner:
...
@@ -456,6 +468,148 @@ class AlignmentRunner:
msa_format
=
"a3m"
,
msa_format
=
"a3m"
,
)
)
if
(
self
.
jackhmmer_uniprot_runner
is
not
None
):
uniprot_out_path
=
os
.
path
.
join
(
output_dir
,
'uniprot_hits.sto'
)
result
=
run_msa_tool
(
self
.
jackhmmer_uniprot_runner
,
input_fasta_path
=
input_fasta_path
,
msa_out_path
=
uniprot_out_path
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniprot_max_hits
,
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
_FastaChain
:
sequence
:
str
description
:
str
def
_make_chain_id_map
(
*
,
sequences
:
Sequence
[
str
],
descriptions
:
Sequence
[
str
],
)
->
Mapping
[
str
,
_FastaChain
]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if
len
(
sequences
)
!=
len
(
descriptions
):
raise
ValueError
(
'sequences and descriptions must have equal length. '
f
'Got
{
len
(
sequences
)
}
!=
{
len
(
descriptions
)
}
.'
)
if
len
(
sequences
)
>
protein
.
PDB_MAX_CHAINS
:
raise
ValueError
(
'Cannot process more chains than the PDB format supports. '
f
'Got
{
len
(
sequences
)
}
chains.'
)
chain_id_map
=
{}
for
chain_id
,
sequence
,
description
in
zip
(
protein
.
PDB_CHAIN_IDS
,
sequences
,
descriptions
):
chain_id_map
[
chain_id
]
=
_FastaChain
(
sequence
=
sequence
,
description
=
description
)
return
chain_id_map
@
contextlib
.
contextmanager
def
temp_fasta_file
(
fasta_str
:
str
):
with
tempfile
.
NamedTemporaryFile
(
'w'
,
suffix
=
'.fasta'
)
as
fasta_file
:
fasta_file
.
write
(
fasta_str
)
fasta_file
.
seek
(
0
)
yield
fasta_file
.
name
def
convert_monomer_features
(
monomer_features
:
FeatureDict
,
chain_id
:
str
)
->
FeatureDict
:
"""Reshapes and modifies monomer features for multimer models."""
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object_
)
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
for
feature_name
,
feature
in
monomer_features
.
items
():
if
feature_name
in
unnecessary_leading_dim_feats
:
# asarray ensures it's a np.ndarray.
feature
=
np
.
asarray
(
feature
[
0
],
dtype
=
feature
.
dtype
)
elif
feature_name
==
'aatype'
:
# The multimer model performs the one-hot operation itself.
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
elif
feature_name
==
'template_aatype'
:
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature
=
np
.
take
(
new_order_list
,
feature
.
astype
(
np
.
int32
),
axis
=
0
)
elif
feature_name
==
'template_all_atom_masks'
:
feature_name
=
'template_all_atom_mask'
converted
[
feature_name
]
=
feature
return
converted
def
int_id_to_str_id
(
num
:
int
)
->
str
:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if
num
<=
0
:
raise
ValueError
(
f
'Only positive integers allowed, got
{
num
}
.'
)
num
=
num
-
1
# 1-based indexing.
output
=
[]
while
num
>=
0
:
output
.
append
(
chr
(
num
%
26
+
ord
(
'A'
)))
num
=
num
//
26
-
1
return
''
.
join
(
output
)
def
add_assembly_features
(
all_chain_features
:
MutableMapping
[
str
,
FeatureDict
],
)
->
MutableMapping
[
str
,
FeatureDict
]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id
=
{}
grouped_chains
=
collections
.
defaultdict
(
list
)
for
chain_id
,
chain_features
in
all_chain_features
.
items
():
seq
=
str
(
chain_features
[
'sequence'
])
if
seq
not
in
seq_to_entity_id
:
seq_to_entity_id
[
seq
]
=
len
(
seq_to_entity_id
)
+
1
grouped_chains
[
seq_to_entity_id
[
seq
]].
append
(
chain_features
)
new_all_chain_features
=
{}
chain_id
=
1
for
entity_id
,
group_chain_features
in
grouped_chains
.
items
():
for
sym_id
,
chain_features
in
enumerate
(
group_chain_features
,
start
=
1
):
new_all_chain_features
[
f
'
{
int_id_to_str_id
(
entity_id
)
}
_
{
sym_id
}
'
]
=
chain_features
seq_length
=
chain_features
[
'seq_length'
]
chain_features
[
'asym_id'
]
=
chain_id
*
np
.
ones
(
seq_length
)
chain_features
[
'sym_id'
]
=
sym_id
*
np
.
ones
(
seq_length
)
chain_features
[
'entity_id'
]
=
entity_id
*
np
.
ones
(
seq_length
)
chain_id
+=
1
return
new_all_chain_features
def
pad_msa
(
np_example
,
min_num_seq
):
np_example
=
dict
(
np_example
)
num_seq
=
np_example
[
'msa'
].
shape
[
0
]
if
num_seq
<
min_num_seq
:
for
feat
in
(
'msa'
,
'deletion_matrix'
,
'bert_mask'
,
'msa_mask'
):
np_example
[
feat
]
=
np
.
pad
(
np_example
[
feat
],
((
0
,
min_num_seq
-
num_seq
),
(
0
,
0
)))
np_example
[
'cluster_bias_mask'
]
=
np
.
pad
(
np_example
[
'cluster_bias_mask'
],
((
0
,
min_num_seq
-
num_seq
),))
return
np_example
class
DataPipeline
:
class
DataPipeline
:
"""Assembles input features."""
"""Assembles input features."""
...
@@ -579,10 +733,9 @@ class DataPipeline:
...
@@ -579,10 +733,9 @@ class DataPipeline:
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
])
msa_features
=
make_msa_features
(
msa_objects
=
[
Msa
(
m
,
d
)
for
m
,
d
in
zip
(
msas
,
deletion_matrices
)]
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
msa_features
=
make_msa_features
(
msa_objects
)
)
return
msa_features
return
msa_features
...
@@ -722,3 +875,126 @@ class DataPipeline:
...
@@ -722,3 +875,126 @@ class DataPipeline:
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
class
DataPipelineMultimer
:
"""Runs the alignment tools and assembles the input features."""
def
__init__
(
self
,
monomer_data_pipeline
:
DataPipeline
,
jackhmmer_binary_path
:
str
,
uniprot_database_path
:
str
,
max_uniprot_hits
:
int
=
50000
,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self
.
_monomer_data_pipeline
=
monomer_data_pipeline
def
_process_single_chain
(
self
,
chain_id
:
str
,
sequence
:
str
,
description
:
str
,
msa_output_dir
:
str
,
is_homomer_or_monomer
:
bool
)
->
FeatureDict
:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str
=
f
'>chain_
{
chain_id
}
\n
{
sequence
}
\n
'
chain_msa_output_dir
=
os
.
path
.
join
(
msa_output_dir
,
chain_id
)
if
not
os
.
path
.
exists
(
chain_msa_output_dir
):
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
input_fasta_path
=
chain_fasta_path
,
alignment_dir
=
chain_msa_output_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if
not
is_homomer_or_monomer
:
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_fasta_path
,
chain_msa_output_dir
)
chain_features
.
update
(
all_seq_msa_features
)
return
chain_features
def
_all_seq_msa_features
(
self
,
input_fasta_path
,
msa_output_dir
):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path
=
os
.
path
.
join
(
msa_output_dir
,
"uniprot_hits.sto"
)
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_uniprot_accession_identifiers'
,
'msa_species_identifiers'
,
)
feats
=
{
f
'
{
k
}
_all_seq'
:
v
for
k
,
v
in
all_seq_features
.
items
()
if
k
in
valid_feats
}
return
feats
def
process
(
self
,
input_fasta_path
:
str
,
msa_output_dir
:
str
,
is_prokaryote
:
bool
=
False
)
->
FeatureDict
:
"""Runs alignment tools on the input sequences and creates features."""
with
open
(
input_fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
chain_id_map
=
_make_chain_id_map
(
sequences
=
input_seqs
,
descriptions
=
input_descs
)
chain_id_map_path
=
os
.
path
.
join
(
msa_output_dir
,
'chain_id_map.json'
)
with
open
(
chain_id_map_path
,
'w'
)
as
f
:
chain_id_map_dict
=
{
chain_id
:
dataclasses
.
asdict
(
fasta_chain
)
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
()
}
json
.
dump
(
chain_id_map_dict
,
f
,
indent
=
4
,
sort_keys
=
True
)
all_chain_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
():
if
fasta_chain
.
sequence
in
sequence_features
:
all_chain_features
[
chain_id
]
=
copy
.
deepcopy
(
sequence_features
[
fasta_chain
.
sequence
])
continue
chain_features
=
self
.
_process_single_chain
(
chain_id
=
chain_id
,
sequence
=
fasta_chain
.
sequence
,
description
=
fasta_chain
.
description
,
msa_output_dir
=
msa_output_dir
,
is_homomer_or_monomer
=
is_homomer_or_monomer
)
chain_features
=
convert_monomer_features
(
chain_features
,
chain_id
=
chain_id
)
all_chain_features
[
chain_id
]
=
chain_features
sequence_features
[
fasta_chain
.
sequence
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
is_prokaryote
=
is_prokaryote
,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
return
np_example
openfold/data/mmcif_parsing.py
View file @
57f869d6
...
@@ -476,6 +476,20 @@ def get_atom_coords(
...
@@ -476,6 +476,20 @@ def get_atom_coords(
pos
[
residue_constants
.
atom_order
[
"SD"
]]
=
[
x
,
y
,
z
]
pos
[
residue_constants
.
atom_order
[
"SD"
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
"SD"
]]
=
1.0
mask
[
residue_constants
.
atom_order
[
"SD"
]]
=
1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd
=
residue_constants
.
atom_order
[
'CD'
]
nh1
=
residue_constants
.
atom_order
[
'NH1'
]
nh2
=
residue_constants
.
atom_order
[
'NH2'
]
if
(
res
.
get_resname
()
==
'ARG'
and
all
(
mask
[
atom_index
]
for
atom_index
in
(
cd
,
nh1
,
nh2
))
and
(
np
.
linalg
.
norm
(
pos
[
nh1
]
-
pos
[
cd
])
>
np
.
linalg
.
norm
(
pos
[
nh2
]
-
pos
[
cd
]))
):
pos
[
nh1
],
pos
[
nh2
]
=
pos
[
nh2
].
copy
(),
pos
[
nh1
].
copy
()
mask
[
nh1
],
mask
[
nh2
]
=
mask
[
nh2
].
copy
(),
mask
[
nh1
].
copy
()
all_atom_positions
[
res_index
]
=
pos
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
all_atom_mask
[
res_index
]
=
mask
...
...
openfold/data/templates.py
View file @
57f869d6
...
@@ -14,8 +14,10 @@
...
@@ -14,8 +14,10 @@
# limitations under the License.
# limitations under the License.
"""Functions for getting templates and calculating template features."""
"""Functions for getting templates and calculating template features."""
import
abc
import
dataclasses
import
dataclasses
import
datetime
import
datetime
import
functools
import
glob
import
glob
import
json
import
json
import
logging
import
logging
...
@@ -65,10 +67,6 @@ class DateError(PrefilterError):
...
@@ -65,10 +67,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
"""An error indicating that the hit date was after the max allowed date."""
class
PdbIdError
(
PrefilterError
):
"""An error indicating that the hit PDB ID was identical to the query."""
class
AlignRatioError
(
PrefilterError
):
class
AlignRatioError
(
PrefilterError
):
"""An error indicating that the hit align ratio to the query was too small."""
"""An error indicating that the hit align ratio to the query was too small."""
...
@@ -188,7 +186,6 @@ def _assess_hhsearch_hit(
...
@@ -188,7 +186,6 @@ def _assess_hhsearch_hit(
hit
:
parsers
.
TemplateHit
,
hit
:
parsers
.
TemplateHit
,
hit_pdb_code
:
str
,
hit_pdb_code
:
str
,
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
datetime
.
datetime
,
release_date_cutoff
:
datetime
.
datetime
,
max_subsequence_ratio
:
float
=
0.95
,
max_subsequence_ratio
:
float
=
0.95
,
...
@@ -202,7 +199,6 @@ def _assess_hhsearch_hit(
...
@@ -202,7 +199,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
different from the value in the actual hit since the original pdb might
have become obsolete.
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
release_dates: Dictionary mapping pdb codes to their structure release
dates.
dates.
release_date_cutoff: Max release date that is valid for this query.
release_date_cutoff: Max release date that is valid for this query.
...
@@ -214,7 +210,6 @@ def _assess_hhsearch_hit(
...
@@ -214,7 +210,6 @@ def _assess_hhsearch_hit(
Raises:
Raises:
DateError: If the hit date was after the max allowed date.
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
LengthError: If the hit was too short.
...
@@ -239,10 +234,6 @@ def _assess_hhsearch_hit(
...
@@ -239,10 +234,6 @@ def _assess_hhsearch_hit(
f
"(
{
release_date_cutoff
}
)."
f
"(
{
release_date_cutoff
}
)."
)
)
if
query_pdb_code
is
not
None
:
if
query_pdb_code
.
lower
()
==
hit_pdb_code
.
lower
():
raise
PdbIdError
(
"PDB code identical to Query PDB code."
)
if
align_ratio
<=
min_align_ratio
:
if
align_ratio
<=
min_align_ratio
:
raise
AlignRatioError
(
raise
AlignRatioError
(
"Proportion of residues aligned to query too small. "
"Proportion of residues aligned to query too small. "
...
@@ -408,9 +399,10 @@ def _realign_pdb_template_to_query(
...
@@ -408,9 +399,10 @@ def _realign_pdb_template_to_query(
)
)
try
:
try
:
(
old_aligned_template
,
new_aligned_template
),
_
=
parsers
.
parse_a3m
(
parsed_a3m
=
parsers
.
parse_a3m
(
aligner
.
align
([
old_template_sequence
,
new_template_sequence
])
aligner
.
align
([
old_template_sequence
,
new_template_sequence
])
)
)
old_aligned_template
,
new_aligned_template
=
parsed_a3m
.
sequences
except
Exception
as
e
:
except
Exception
as
e
:
raise
QueryToTemplateAlignError
(
raise
QueryToTemplateAlignError
(
"Could not align old template %s to template %s (%s_%s). Error: %s"
"Could not align old template %s to template %s (%s_%s). Error: %s"
...
@@ -752,7 +744,6 @@ class SingleHitResult:
...
@@ -752,7 +744,6 @@ class SingleHitResult:
def
_prefilter_hit
(
def
_prefilter_hit
(
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
hit
:
parsers
.
TemplateHit
,
max_template_date
:
datetime
.
datetime
,
max_template_date
:
datetime
.
datetime
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
...
@@ -773,7 +764,6 @@ def _prefilter_hit(
...
@@ -773,7 +764,6 @@ def _prefilter_hit(
hit
=
hit
,
hit
=
hit
,
hit_pdb_code
=
hit_pdb_code
,
hit_pdb_code
=
hit_pdb_code
,
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
release_dates
=
release_dates
,
release_dates
=
release_dates
,
release_date_cutoff
=
max_template_date
,
release_date_cutoff
=
max_template_date
,
)
)
...
@@ -781,9 +771,7 @@ def _prefilter_hit(
...
@@ -781,9 +771,7 @@ def _prefilter_hit(
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
if
strict_error_check
and
isinstance
(
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
DuplicateError
)):
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
):
# In strict mode we treat some prefilter cases as errors.
# In strict mode we treat some prefilter cases as errors.
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
...
@@ -792,9 +780,16 @@ def _prefilter_hit(
...
@@ -792,9 +780,16 @@ def _prefilter_hit(
return
PrefilterResult
(
valid
=
True
,
error
=
None
,
warning
=
None
)
return
PrefilterResult
(
valid
=
True
,
error
=
None
,
warning
=
None
)
@
functools
.
lru_cache
(
16
,
typed
=
False
)
def
_read_file
(
path
):
with
open
(
path
,
'r'
)
as
f
:
file_data
=
f
.
read
()
return
file_data
def
_process_single_hit
(
def
_process_single_hit
(
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
hit
:
parsers
.
TemplateHit
,
mmcif_dir
:
str
,
mmcif_dir
:
str
,
max_template_date
:
datetime
.
datetime
,
max_template_date
:
datetime
.
datetime
,
...
@@ -832,8 +827,7 @@ def _process_single_hit(
...
@@ -832,8 +827,7 @@ def _process_single_hit(
template_sequence
,
template_sequence
,
)
)
# Fail if we can't find the mmCIF file.
# Fail if we can't find the mmCIF file.
with
open
(
cif_path
,
"r"
)
as
cif_file
:
cif_string
=
_read_file
(
cif_path
)
cif_string
=
cif_file
.
read
()
parsing_result
=
mmcif_parsing
.
parse
(
parsing_result
=
mmcif_parsing
.
parse
(
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
...
@@ -866,7 +860,11 @@ def _process_single_hit(
...
@@ -866,7 +860,11 @@ def _process_single_hit(
kalign_binary_path
=
kalign_binary_path
,
kalign_binary_path
=
kalign_binary_path
,
_zero_center_positions
=
_zero_center_positions
,
_zero_center_positions
=
_zero_center_positions
,
)
)
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
if
hit
.
sum_probs
is
None
:
features
[
"template_sum_probs"
]
=
[
0
]
else
:
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
# It is possible there were some errors when parsing the other chains in the
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# mmCIF file, but the template features for the chain we want were still
...
@@ -920,8 +918,8 @@ class TemplateSearchResult:
...
@@ -920,8 +918,8 @@ class TemplateSearchResult:
warnings
:
Sequence
[
str
]
warnings
:
Sequence
[
str
]
class
TemplateHitFeaturizer
:
class
TemplateHitFeaturizer
(
abc
.
ABC
)
:
"""A class for turning
hhr hits to
template features."""
"""A
n abstract base
class for turning template
hits to
features."""
def
__init__
(
def
__init__
(
self
,
self
,
mmcif_dir
:
str
,
mmcif_dir
:
str
,
...
@@ -993,10 +991,18 @@ class TemplateHitFeaturizer:
...
@@ -993,10 +991,18 @@ class TemplateHitFeaturizer:
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_zero_center_positions
=
_zero_center_positions
self
.
_zero_center_positions
=
_zero_center_positions
@
abc
.
abstractmethod
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
class
HhsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
def
get_templates
(
self
,
self
,
query_sequence
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
query_release_date
:
Optional
[
datetime
.
datetime
],
query_release_date
:
Optional
[
datetime
.
datetime
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
)
->
TemplateSearchResult
:
)
->
TemplateSearchResult
:
...
@@ -1025,7 +1031,6 @@ class TemplateHitFeaturizer:
...
@@ -1025,7 +1031,6 @@ class TemplateHitFeaturizer:
for
hit
in
hits
:
for
hit
in
hits
:
prefilter_result
=
_prefilter_hit
(
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
hit
=
hit
,
max_template_date
=
template_cutoff_date
,
max_template_date
=
template_cutoff_date
,
release_dates
=
self
.
_release_dates
,
release_dates
=
self
.
_release_dates
,
...
@@ -1105,3 +1110,88 @@ class TemplateHitFeaturizer:
...
@@ -1105,3 +1110,88 @@ class TemplateHitFeaturizer:
return
TemplateSearchResult
(
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
)
)
class
HmmsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
already_seen
=
set
()
errors
=
[]
warnings
=
[]
if
not
hits
or
hits
[
0
].
sum_probs
is
None
:
sorted_hits
=
hits
else
:
sorted_hits
=
sorted
(
hits
,
key
=
lambda
x
:
x
.
sum_probs
,
reverse
=
True
)
for
hit
in
sorted_hits
:
if
(
len
(
already_seen
)
>=
self
.
_max_hits
):
break
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
)
if
result
.
error
:
errors
.
append
(
result
.
error
)
if
result
.
warning
:
warnings
.
append
(
result
.
warning
)
if
result
.
features
is
None
:
logging
.
debug
(
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
.
name
,
result
.
error
,
result
.
warning
,
)
else
:
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
(
already_seen_key
in
already_seen
):
continue
# Increment the hit counter, since we got features out of this hit.
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
if
already_seen
:
for
name
in
template_features
:
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
else
:
num_res
=
len
(
query_sequence
)
# Construct a default template with all zeros.
template_features
=
{
"template_aatype"
:
np
.
zeros
(
(
1
,
num_res
,
len
(
residue_constants
.
restypes_with_x_and_gap
)),
np
.
float32
),
"template_all_atom_masks"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
),
np
.
float32
),
"template_all_atom_positions"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
,
3
),
np
.
float32
),
"template_domain_names"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sequence"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sum_probs"
:
np
.
array
([
0
],
dtype
=
np
.
float32
),
}
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
,
)
openfold/data/tools/hhblits.py
View file @
57f869d6
...
@@ -18,7 +18,7 @@ import glob
...
@@ -18,7 +18,7 @@ import glob
import
logging
import
logging
import
os
import
os
import
subprocess
import
subprocess
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
from
openfold.data.tools
import
utils
from
openfold.data.tools
import
utils
...
@@ -99,9 +99,9 @@ class HHBlits:
...
@@ -99,9 +99,9 @@ class HHBlits:
self
.
p
=
p
self
.
p
=
p
self
.
z
=
z
self
.
z
=
z
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
def
query
(
self
,
input_fasta_path
:
str
)
->
List
[
Mapping
[
str
,
Any
]
]
:
"""Queries the database using HHblits."""
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
db_cmd
=
[]
db_cmd
=
[]
...
@@ -172,4 +172,4 @@ class HHBlits:
...
@@ -172,4 +172,4 @@ class HHBlits:
n_iter
=
self
.
n_iter
,
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
,
e_value
=
self
.
e_value
,
)
)
return
raw_output
return
[
raw_output
]
openfold/data/tools/hhsearch.py
View file @
57f869d6
...
@@ -20,6 +20,7 @@ import os
...
@@ -20,6 +20,7 @@ import os
import
subprocess
import
subprocess
from
typing
import
Sequence
from
typing
import
Sequence
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
from
openfold.data.tools
import
utils
...
@@ -62,9 +63,17 @@ class HHSearch:
...
@@ -62,9 +63,17 @@ class HHSearch:
f
"Could not find HHsearch database
{
database_path
}
"
f
"Could not find HHsearch database
{
database_path
}
"
)
)
@
property
def
output_format
(
self
)
->
str
:
return
'hhr'
@
property
def
input_format
(
self
)
->
str
:
return
'a3m'
def
query
(
self
,
a3m
:
str
)
->
str
:
def
query
(
self
,
a3m
:
str
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
with
open
(
input_path
,
"w"
)
as
f
:
with
open
(
input_path
,
"w"
)
as
f
:
...
@@ -104,3 +113,11 @@ class HHSearch:
...
@@ -104,3 +113,11 @@ class HHSearch:
with
open
(
hhr_path
)
as
f
:
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
hhr
=
f
.
read
()
return
hhr
return
hhr
def
get_template_hits
(
self
,
output_string
:
str
,
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
"""Gets parsed template hits from the raw string output by the tool"""
del
input_sequence
# Used by hmmsearch but not needed for hhsearch
return
parsers
.
parse_hhr
(
output_string
)
openfold/data/tools/hmmbuild.py
0 → 100644
View file @
57f869d6
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import
os
import
re
import
subprocess
from
absl
import
logging
from
openfold.data.tools
import
utils
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self, *, binary_path: str, singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
            binary_path: The path to the hmmbuild executable.
            singlemx: Whether to use --singlemx flag. If True, it forces
                HMMBuild to just use a common substitution score matrix.

        Raises:
            RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds a HMM for the aligned sequences given as a Stockholm string.

        Args:
            sto: A string with the aligned sequences in the Stockholm format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds a HMM for the aligned sequences given as an A3M string.

        Args:
            a3m: A string with the aligned sequences in the A3M format.

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                # Remove inserted residues (lowercase letters in A3M).
                line = re.sub('[a-z]+', '', line)
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds a HMM for the aligned sequences given as an MSA string.

        Args:
            msa: A string with the aligned sequences, in A3M or STO format.
            model_construction: Whether to use reference annotation in the msa
                to determine consensus columns ('hand') or default ('fast').

        Returns:
            A string with the profile in the HMM format.

        Raises:
            RuntimeError: If hmmbuild fails.
            ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            # BUG FIX: the two adjacent string literals previously rendered as
            # "onlyhand and fast"; a trailing space restores the intended text.
            raise ValueError(
                f'Invalid model_construction {model_construction} - only '
                'hand and fast supported.'
            )

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    'hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                    stdout.decode('utf-8'),
                    stderr.decode('utf-8'),
                )

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8'))
                )

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

        return hmm
openfold/data/tools/hmmsearch.py
0 → 100644
View file @
57f869d6
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import
os
import
subprocess
from
typing
import
Optional
,
Sequence
from
absl
import
logging
from
openfold.data
import
parsers
from
openfold.data.tools
import
hmmbuild
from
openfold.data.tools
import
utils
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch.

        Raises:
            RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(
            binary_path=hmmbuild_binary_path
        )
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                '--F1', '0.1',
                '--F2', '0.1',
                '--F3', '0.1',
                '--incE', '100',
                '-E', '100',
                '--domE', '100',
                '--incdomE', '100',
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error(
                'Could not find hmmsearch database %s', database_path
            )
            raise ValueError(
                f'Could not find hmmsearch database {database_path}'
            )

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def query(self, msa_sto: str) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        profile = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand'
        )
        return self.query_with_hmm(profile)

    def query_with_hmm(self, hmm: str) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as tmp_dir:
            profile_path = os.path.join(tmp_dir, 'query.hmm')
            sto_out_path = os.path.join(tmp_dir, 'output.sto')
            with open(profile_path, 'w') as handle:
                handle.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8',
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', sto_out_path,
                profile_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            timing_msg = (
                f'hmmsearch ({os.path.basename(self.database_path)}) query'
            )
            with utils.timing(timing_msg):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8'))
                )

            with open(sto_out_path) as handle:
                out_msa = handle.read()

        return out_msa

    def get_template_hits(self,
        output_string: str,
        input_sequence: str,
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        a3m_string = parsers.convert_stockholm_to_a3m(
            output_string, remove_first_row_gaps=False
        )
        return parsers.parse_hmmsearch_a3m(
            query_sequence=input_sequence,
            a3m_string=a3m_string,
            skip_first=False,
        )
openfold/data/tools/jackhmmer.py
View file @
57f869d6
...
@@ -23,6 +23,7 @@ import subprocess
...
@@ -23,6 +23,7 @@ import subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
urllib
import
request
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
from
openfold.data.tools
import
utils
...
@@ -93,10 +94,13 @@ class Jackhmmer:
...
@@ -93,10 +94,13 @@ class Jackhmmer:
self
.
streaming_callback
=
streaming_callback
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
self
,
input_fasta_path
:
str
,
database_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.sto"
)
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.sto"
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# The F1/F2/F3 are the expected proportion to pass each of the filtering
...
@@ -167,8 +171,11 @@ class Jackhmmer:
...
@@ -167,8 +171,11 @@ class Jackhmmer:
with
open
(
tblout_path
)
as
f
:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
tbl
=
f
.
read
()
with
open
(
sto_path
)
as
f
:
if
(
max_sequences
is
None
):
sto
=
f
.
read
()
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
else
:
sto
=
parsers
.
truncate_stockholm_msa
(
sto_path
,
max_sequences
)
raw_output
=
dict
(
raw_output
=
dict
(
sto
=
sto
,
sto
=
sto
,
...
@@ -180,10 +187,16 @@ class Jackhmmer:
...
@@ -180,10 +187,16 @@ class Jackhmmer:
return
raw_output
return
raw_output
def
query
(
self
,
input_fasta_path
:
str
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
def
query
(
self
,
input_fasta_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
"""Queries the database using Jackhmmer."""
"""Queries the database using Jackhmmer."""
if
self
.
num_streamed_chunks
is
None
:
if
self
.
num_streamed_chunks
is
None
:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
single_chunk_result
=
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
,
max_sequences
,
)
return
[
single_chunk_result
]
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
...
@@ -217,12 +230,20 @@ class Jackhmmer:
...
@@ -217,12 +230,20 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
# Run Jackhmmer with the chunk
future
.
result
()
future
.
result
()
chunked_output
.
append
(
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
))
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
),
max_sequences
)
)
)
# Remove the local copy of the chunk
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
os
.
remove
(
db_local_chunk
(
i
))
future
=
next_future
future
=
next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if
(
i
<
self
.
num_streamed_chunks
):
future
=
next_future
if
self
.
streaming_callback
:
if
self
.
streaming_callback
:
self
.
streaming_callback
(
i
)
self
.
streaming_callback
(
i
)
return
chunked_output
return
chunked_output
openfold/data/tools/kalign.py
View file @
57f869d6
...
@@ -72,7 +72,7 @@ class Kalign:
...
@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
)
)
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"input.fasta"
)
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"input.fasta"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
...
...
openfold/model/structure_module.py
View file @
57f869d6
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
math
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
,
Union
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
from
openfold.np.residue_constants
import
(
...
@@ -151,6 +151,40 @@ class AngleResnet(nn.Module):
...
@@ -151,6 +151,40 @@ class AngleResnet(nn.Module):
return
unnormalized_s
,
s
return
unnormalized_s
,
s
class PointProjection(nn.Module):
    """Projects per-residue activations to sets of 3D points.

    Points are produced in each residue's local frame and transformed into
    the global frame via the provided rigids.

    Args:
        c_hidden: Input channel dimension of the activations.
        num_points: Number of points to project per head.
        no_heads: Number of attention heads.
        return_local_points: If True, forward() also returns the
            pre-transform (local-frame) points.
    """
    def __init__(self,
        c_hidden: int,
        num_points: int,
        no_heads: int,  # BUG FIX: comma was missing here (SyntaxError)
        return_local_points: bool = False,
    ):
        super().__init__()
        self.return_local_points = return_local_points
        self.no_heads = no_heads
        # BUG FIX: the projection must produce 3 coordinates for each of
        # no_heads * num_points points; the head factor was missing, which
        # made the reshape in forward() inconsistent with the output size.
        self.linear = Linear(c_hidden, no_heads * 3 * num_points)

    def forward(self,
        activations: torch.Tensor,
        rigids: Rigid3Array,
    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
        # TODO: Needs to run in high precision during training
        points_local = self.linear(activations)

        # [*, N_res, no_heads, 3 * num_points]
        # BUG FIX: reshape() was called with a tuple plus extra ints; the
        # leading shape must be unpacked.
        points_local = points_local.reshape(
            *points_local.shape[:-1],
            self.no_heads,
            -1,
        )

        # BUG FIX: split into 3 equal chunks (x, y, z components) rather
        # than into chunks of size 3 — Vec3Array takes exactly 3 tensors.
        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )
        points_local = Vec3Array(*points_local)
        points_global = rigids[..., None, None].apply_to_point(points_local)

        if self.return_local_points:
            return points_global, points_local

        return points_global
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
"""
"""
Implements Algorithm 22.
Implements Algorithm 22.
...
@@ -200,13 +234,23 @@ class InvariantPointAttention(nn.Module):
...
@@ -200,13 +234,23 @@ class InvariantPointAttention(nn.Module):
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
)
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
)
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
hpq
=
self
.
no_heads
*
self
.
no_qk_points
*
3
self
.
linear_q_points
=
PointProjection
(
self
.
linear_q_points
=
Linear
(
self
.
c_s
,
hpq
)
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_heads
)
hpkv
=
self
.
no_heads
*
(
self
.
no_qk_points
+
self
.
no_v_points
)
*
3
self
.
linear_k_points
=
PointProjection
(
self
.
linear_kv_points
=
Linear
(
self
.
c_s
,
hpkv
)
self
.
c_s
,
self
.
no_qk_points
self
.
no_heads
,
)
hpv
=
self
.
no_heads
*
self
.
no_v_points
*
3
self
.
linear_v_points
=
PointProjection
(
self
.
c_s
,
self
.
no_v_points
self
.
no_heads
,
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
...
@@ -257,35 +301,14 @@ class InvariantPointAttention(nn.Module):
...
@@ -257,35 +301,14 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, C_hidden]
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
# [*, N_res, H
*
P_q
* 3
]
# [*, N_res, H
,
P_q
k
]
q_pts
=
self
.
linear_q_points
(
s
)
q_pts
=
self
.
linear_q_points
(
s
,
r
)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H, P_qk, 3]
# [*, N_res, H * P_q, 3]
k_pts
=
self
.
linear_k_points
(
s
,
r
)
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
r
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_q, 3]
# [*, N_res, H, P_v, 3]
q_pts
=
q_pts
.
view
(
v_pts
=
self
.
linear_v_points
(
s
,
r
)
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts
=
self
.
linear_kv_points
(
s
)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
r
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
# [*, N_res, H, P_q/P_v, 3]
k_pts
,
v_pts
=
torch
.
split
(
kv_pts
,
[
self
.
no_qk_points
,
self
.
no_v_points
],
dim
=-
2
)
##########################
##########################
# Compute attention scores
# Compute attention scores
...
@@ -302,8 +325,8 @@ class InvariantPointAttention(nn.Module):
...
@@ -302,8 +325,8 @@ class InvariantPointAttention(nn.Module):
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
pt_att
*
*
2
pt_att
=
pt_att
*
pt_att
+
self
.
eps
# [*, N_res, N_res, H, P_q]
# [*, N_res, N_res, H, P_q]
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
...
@@ -340,26 +363,20 @@ class InvariantPointAttention(nn.Module):
...
@@ -340,26 +363,20 @@ class InvariantPointAttention(nn.Module):
# As DeepMind explains, this manual matmul ensures that the operation
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# happens in float32.
# [*, H, 3, N_res, P_v]
# [*, N_res, H, P_v]
o_pt
=
torch
.
sum
(
o_pt
=
v_pts
.
tensor_dot
(
(
permute_final_dims
(
a
,
(
1
,
2
,
0
)).
unsqueeze
(
-
1
)
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
)
)
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, P_v]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
r
[...,
None
,
None
].
apply_inverse_to_point
(
o_pt
)
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
2
)
# [*, N_res, H * P_v, 3]
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
o_pt
.
shape
[:
-
2
]
+
(
-
1
,))
# [*, N_res, H * P_v]
o_pt_norm
=
o_pt
.
norm
(
self
.
eps
)
# [*, N_res, H, C_z]
# [*, N_res, H, C_z]
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
.
to
(
dtype
=
a
.
dtype
))
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
.
to
(
dtype
=
a
.
dtype
))
...
@@ -370,7 +387,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -370,7 +387,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
# [*, N_res, C_s]
s
=
self
.
linear_out
(
s
=
self
.
linear_out
(
torch
.
cat
(
torch
.
cat
(
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
)
,
o_pt_norm
,
o_pair
),
dim
=-
1
(
o
,
*
o_pt
,
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
.
dtype
)
).
to
(
dtype
=
z
.
dtype
)
)
)
...
...
openfold/np/residue_constants.py
View file @
57f869d6
...
@@ -24,8 +24,6 @@ from importlib import resources
...
@@ -24,8 +24,6 @@ from importlib import resources
import
numpy
as
np
import
numpy
as
np
import
tree
import
tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca
=
3.80209737096
ca_ca
=
3.80209737096
...
@@ -1309,3 +1307,179 @@ def aatype_to_str_sequence(aatype):
...
@@ -1309,3 +1307,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x
[
aatype
[
i
]]
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
for
i
in
range
(
len
(
aatype
))
])
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue
        types are in the order specified in residue_constants.restypes +
        unknown residue type at the end. For chi angles which are not defined
        on the residue, the positions indices are by default set to 0.
    """
    all_indices = []
    for one_letter in restypes:
        resname = restype_1to3[one_letter]
        per_residue = [
            [atom_order[atom] for atom in chi]
            for chi in chi_angles_atoms[resname]
        ]
        # Pad with zeros for the chi angles the residue does not have.
        per_residue.extend([[0, 0, 0, 0]] * (4 - len(per_residue)))
        all_indices.append(per_residue)

    # Trailing all-zero entry for the UNKNOWN residue type.
    all_indices.append([[0, 0, 0, 0]] * 4)

    return np.array(all_indices)
def _make_renaming_matrices():
    """Matrices to map atoms to symmetry partners in ambiguous case.

    Returns:
        A [21, 14, 14] float32 array of per-residue-type permutation matrices
        over atom14 slots. For unambiguous residue types the matrix is the
        identity; for residue types listed in residue_atom_renaming_swaps it
        swaps the rows/columns of the symmetric atom pairs.
    """
    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative groundtruth coordinates where the naming is swapped
    restype_3 = [restype_1to3[res] for res in restypes]
    restype_3 += ['UNK']
    # Matrices for renaming ambiguous atoms. Start from identity for every
    # residue type; ambiguous ones are overwritten below.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_atom_renaming_swaps.items():
        # correspondences[i] = j means atom14 slot i is renamed to slot j.
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            # Record the swap in both directions.
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Turn the index correspondence into a permutation matrix.
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.
        all_matrices[resname] = renaming_matrix.astype(np.float32)
    # Stack in the restypes + UNK order to index by residue type.
    renaming_matrices = np.stack(
        [all_matrices[restype] for restype in restype_3]
    )
    return renaming_matrices
def _make_restype_atom37_mask():
    """Mask of which atoms are present for which residue type in atom37."""
    # Rows for residue types with no atom entries remain all-zero.
    mask = np.zeros([21, 37], dtype=np.float32)
    for row, letter in enumerate(restypes):
        for atom_name in residue_atoms[restype_1to3[letter]]:
            mask[row, atom_order[atom_name]] = 1
    return mask
def _make_restype_atom14_mask():
    """Mask of which atoms are present for which residue type in atom14."""
    rows = [
        [(1. if name else 0.)
         for name in restype_name_to_atom14_names[restype_1to3[rt]]]
        for rt in restypes
    ]
    # All-zero row for the unknown residue type.
    rows.append([0.] * 14)
    return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
    """Map from atom37 to atom14 per residue type."""
    mapping = []  # mapping (restype, atom37) --> atom14
    for rt in restypes:
        atom14_names = restype_name_to_atom14_names[restype_1to3[rt]]
        idx14 = {name: i for i, name in enumerate(atom14_names)}
        # Atoms absent from the residue's atom14 set map to slot 0.
        mapping.append([idx14.get(name, 0) for name in atom_types])
    # All-zero row for the unknown residue type.
    mapping.append([0] * 37)
    return np.array(mapping, dtype=np.int32)
def _make_restype_atom14_to_atom37():
    """Map from atom14 to atom37 per residue type."""
    # mapping (restype, atom14) --> atom37
    mapping = [
        [(atom_order[name] if name else 0)
         for name in restype_name_to_atom14_names[restype_1to3[rt]]]
        for rt in restypes
    ]
    # Add dummy mapping for restype 'UNK'
    mapping.append([0] * 14)
    return np.array(mapping, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
    """Mask which atoms are ambiguous in atom14."""
    # create an ambiguous atoms mask. shape: (21, 14)
    ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_atom_renaming_swaps.items():
        row = restype_order[restype_3to1[resname]]
        atom14_names = restype_name_to_atom14_names[resname]
        for first_atom, second_atom in swap.items():
            # Both members of a symmetric pair are ambiguous.
            ambiguous[row, atom14_names.index(first_atom)] = 1
            ambiguous[row, atom14_names.index(second_atom)] = 1
    return ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
    """Create Map from rigidgroups to atom37 indices.

    Returns:
        A [21, 8, 3] integer array: for each residue type and each of the 8
        rigid groups, the atom37 indices of the 3 atoms defining that group's
        frame. Groups that are never assigned below resolve to index 0.
    """
    # Create an array with the atom names.
    # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
    base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # 0: backbone frame
    base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # 3: 'psi-group'
    base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # 4,5,6,7: 'chi1,2,3,4-group'
    # Only groups permitted by chi_angles_mask are filled; the frame atoms
    # are the last three atoms of the chi angle's 4-atom definition.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for chi_idx in range(4):
            if chi_angles_mask[restype][chi_idx]:
                atom_names = chi_angles_atoms[resname][chi_idx]
                base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]

    # Translate atom names into atom37 indices.
    lookuptable = atom_order.copy()
    # Unassigned entries (empty string) map to index 0.
    lookuptable[''] = 0
    restype_rigidgroup_base_atom37_idx = np.vectorize(
        lambda x: lookuptable[x]
    )(base_atom_names)
    return restype_rigidgroup_base_atom37_idx
# Precomputed constant tables derived from the residue definitions above.
CHI_ATOM_INDICES = _make_chi_atom_indices()  # [21, 4, 4]
RENAMING_MATRICES = _make_renaming_matrices()  # [21, 14, 14]
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()  # [21, 14]
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()  # [21, 37]
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()  # [21, 37]
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()  # [21, 14]
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()  # [21, 14]
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = (
    _make_restype_rigidgroup_base_atom37_idx()
)  # [21, 8, 3]

# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1  # backbone frame: present for all types
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1  # psi-group: present for all types
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask  # chi groups per residue type
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment