Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
57f869d6
Commit
57f869d6
authored
Mar 09, 2022
by
Gustaf Ahdritz
Browse files
Continue work on AlphaFold-Multimer
parent
100485dd
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
986 additions
and
96 deletions
+986
-96
openfold/config.py
openfold/config.py
+10
-0
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+280
-4
openfold/data/mmcif_parsing.py
openfold/data/mmcif_parsing.py
+14
-0
openfold/data/templates.py
openfold/data/templates.py
+115
-25
openfold/data/tools/hhblits.py
openfold/data/tools/hhblits.py
+4
-4
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+18
-1
openfold/data/tools/hmmbuild.py
openfold/data/tools/hmmbuild.py
+137
-0
openfold/data/tools/hmmsearch.py
openfold/data/tools/hmmsearch.py
+134
-0
openfold/data/tools/jackhmmer.py
openfold/data/tools/jackhmmer.py
+28
-7
openfold/data/tools/kalign.py
openfold/data/tools/kalign.py
+1
-1
openfold/model/structure_module.py
openfold/model/structure_module.py
+69
-52
openfold/np/residue_constants.py
openfold/np/residue_constants.py
+176
-2
No files found.
openfold/config.py
View file @
57f869d6
...
...
@@ -74,6 +74,8 @@ def model_config(name, train=False, low_prec=False):
c
.
model
.
template
.
enabled
=
False
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
c
.
model
.
update
(
multimer_model_config_update
)
else
:
raise
ValueError
(
"Invalid model name"
)
...
...
@@ -493,3 +495,11 @@ config = mlc.ConfigDict(
"ema"
:
{
"decay"
:
0.999
},
}
)
multimer_model_config_update
=
mlc
.
ConfigDict
(
"relative_encoding"
:
{
"enabled"
:
True
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
}
)
openfold/data/data_pipeline.py
View file @
57f869d6
...
...
@@ -25,6 +25,7 @@ from openfold.data import (
parsers
,
mmcif_parsing
,
msa_identifiers
,
msa_pairing
,
)
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools.utils
import
to_date
...
...
@@ -277,11 +278,13 @@ class AlignmentRunner:
mgnify_database_path
:
Optional
[
str
]
=
None
,
bfd_database_path
:
Optional
[
str
]
=
None
,
uniclust30_database_path
:
Optional
[
str
]
=
None
,
uniprot_database_path
:
Optional
[
str
]
=
None
,
template_searcher
:
Optional
[
TemplateSearcher
]
=
None
,
use_small_bfd
:
Optional
[
bool
]
=
None
,
no_cpus
:
Optional
[
int
]
=
None
,
uniref_max_hits
:
int
=
10000
,
mgnify_max_hits
:
int
=
5000
,
uniprot_max_hits
:
int
=
50000
,
):
"""
Args:
...
...
@@ -320,6 +323,7 @@ class AlignmentRunner:
uniref90_database_path
,
mgnify_database_path
,
bfd_database_path
if
use_small_bfd
else
None
,
uniprot_database_path
,
],
},
"hhblits"
:
{
...
...
@@ -339,6 +343,7 @@ class AlignmentRunner:
self
.
uniref_max_hits
=
uniref_max_hits
self
.
mgnify_max_hits
=
mgnify_max_hits
self
.
uniprot_max_hits
=
uniprot_max_hits
self
.
use_small_bfd
=
use_small_bfd
if
(
no_cpus
is
None
):
...
...
@@ -381,6 +386,13 @@ class AlignmentRunner:
n_cpu
=
no_cpus
,
)
self
.
_uniprot_msa_runner
=
None
if
(
uniprot_database_path
is
not
None
):
self
.
jackhmmer_uniprot_runner
=
jackhmmer
.
Jackhmmer
(
binary_path
=
jackhmmer_binary_path
,
database_path
=
uniprot_database_path
)
if
(
template_searcher
is
not
None
and
self
.
jackhmmer_uniref90_runner
is
None
):
...
...
@@ -456,6 +468,148 @@ class AlignmentRunner:
msa_format
=
"a3m"
,
)
if
(
self
.
jackhmmer_uniprot_runner
is
not
None
):
uniprot_out_path
=
os
.
path
.
join
(
output_dir
,
'uniprot_hits.sto'
)
result
=
run_msa_tool
(
self
.
jackhmmer_uniprot_runner
,
input_fasta_path
=
input_fasta_path
,
msa_out_path
=
uniprot_out_path
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniprot_max_hits
,
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
_FastaChain
:
sequence
:
str
description
:
str
def
_make_chain_id_map
(
*
,
sequences
:
Sequence
[
str
],
descriptions
:
Sequence
[
str
],
)
->
Mapping
[
str
,
_FastaChain
]:
"""Makes a mapping from PDB-format chain ID to sequence and description."""
if
len
(
sequences
)
!=
len
(
descriptions
):
raise
ValueError
(
'sequences and descriptions must have equal length. '
f
'Got
{
len
(
sequences
)
}
!=
{
len
(
descriptions
)
}
.'
)
if
len
(
sequences
)
>
protein
.
PDB_MAX_CHAINS
:
raise
ValueError
(
'Cannot process more chains than the PDB format supports. '
f
'Got
{
len
(
sequences
)
}
chains.'
)
chain_id_map
=
{}
for
chain_id
,
sequence
,
description
in
zip
(
protein
.
PDB_CHAIN_IDS
,
sequences
,
descriptions
):
chain_id_map
[
chain_id
]
=
_FastaChain
(
sequence
=
sequence
,
description
=
description
)
return
chain_id_map
@
contextlib
.
contextmanager
def
temp_fasta_file
(
fasta_str
:
str
):
with
tempfile
.
NamedTemporaryFile
(
'w'
,
suffix
=
'.fasta'
)
as
fasta_file
:
fasta_file
.
write
(
fasta_str
)
fasta_file
.
seek
(
0
)
yield
fasta_file
.
name
def
convert_monomer_features
(
monomer_features
:
FeatureDict
,
chain_id
:
str
)
->
FeatureDict
:
"""Reshapes and modifies monomer features for multimer models."""
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object_
)
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
for
feature_name
,
feature
in
monomer_features
.
items
():
if
feature_name
in
unnecessary_leading_dim_feats
:
# asarray ensures it's a np.ndarray.
feature
=
np
.
asarray
(
feature
[
0
],
dtype
=
feature
.
dtype
)
elif
feature_name
==
'aatype'
:
# The multimer model performs the one-hot operation itself.
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
elif
feature_name
==
'template_aatype'
:
feature
=
np
.
argmax
(
feature
,
axis
=-
1
).
astype
(
np
.
int32
)
new_order_list
=
residue_constants
.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature
=
np
.
take
(
new_order_list
,
feature
.
astype
(
np
.
int32
),
axis
=
0
)
elif
feature_name
==
'template_all_atom_masks'
:
feature_name
=
'template_all_atom_mask'
converted
[
feature_name
]
=
feature
return
converted
def
int_id_to_str_id
(
num
:
int
)
->
str
:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style,
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if
num
<=
0
:
raise
ValueError
(
f
'Only positive integers allowed, got
{
num
}
.'
)
num
=
num
-
1
# 1-based indexing.
output
=
[]
while
num
>=
0
:
output
.
append
(
chr
(
num
%
26
+
ord
(
'A'
)))
num
=
num
//
26
-
1
return
''
.
join
(
output
)
def
add_assembly_features
(
all_chain_features
:
MutableMapping
[
str
,
FeatureDict
],
)
->
MutableMapping
[
str
,
FeatureDict
]:
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id
=
{}
grouped_chains
=
collections
.
defaultdict
(
list
)
for
chain_id
,
chain_features
in
all_chain_features
.
items
():
seq
=
str
(
chain_features
[
'sequence'
])
if
seq
not
in
seq_to_entity_id
:
seq_to_entity_id
[
seq
]
=
len
(
seq_to_entity_id
)
+
1
grouped_chains
[
seq_to_entity_id
[
seq
]].
append
(
chain_features
)
new_all_chain_features
=
{}
chain_id
=
1
for
entity_id
,
group_chain_features
in
grouped_chains
.
items
():
for
sym_id
,
chain_features
in
enumerate
(
group_chain_features
,
start
=
1
):
new_all_chain_features
[
f
'
{
int_id_to_str_id
(
entity_id
)
}
_
{
sym_id
}
'
]
=
chain_features
seq_length
=
chain_features
[
'seq_length'
]
chain_features
[
'asym_id'
]
=
chain_id
*
np
.
ones
(
seq_length
)
chain_features
[
'sym_id'
]
=
sym_id
*
np
.
ones
(
seq_length
)
chain_features
[
'entity_id'
]
=
entity_id
*
np
.
ones
(
seq_length
)
chain_id
+=
1
return
new_all_chain_features
def
pad_msa
(
np_example
,
min_num_seq
):
np_example
=
dict
(
np_example
)
num_seq
=
np_example
[
'msa'
].
shape
[
0
]
if
num_seq
<
min_num_seq
:
for
feat
in
(
'msa'
,
'deletion_matrix'
,
'bert_mask'
,
'msa_mask'
):
np_example
[
feat
]
=
np
.
pad
(
np_example
[
feat
],
((
0
,
min_num_seq
-
num_seq
),
(
0
,
0
)))
np_example
[
'cluster_bias_mask'
]
=
np
.
pad
(
np_example
[
'cluster_bias_mask'
],
((
0
,
min_num_seq
-
num_seq
),))
return
np_example
class
DataPipeline
:
"""Assembles input features."""
...
...
@@ -579,10 +733,9 @@ class DataPipeline:
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
msa_features
=
make_msa_features
(
msas
=
msas
,
deletion_matrices
=
deletion_matrices
,
)
msa_objects
=
[
Msa
(
m
,
d
)
for
m
,
d
in
zip
(
msas
,
deletion_matrices
)]
msa_features
=
make_msa_features
(
msa_objects
)
return
msa_features
...
...
@@ -722,3 +875,126 @@ class DataPipeline:
return
{
**
core_feats
,
**
template_features
,
**
msa_features
}
class
DataPipelineMultimer
:
"""Runs the alignment tools and assembles the input features."""
def
__init__
(
self
,
monomer_data_pipeline
:
DataPipeline
,
jackhmmer_binary_path
:
str
,
uniprot_database_path
:
str
,
max_uniprot_hits
:
int
=
50000
,
):
"""Initializes the data pipeline.
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
self
.
_monomer_data_pipeline
=
monomer_data_pipeline
def
_process_single_chain
(
self
,
chain_id
:
str
,
sequence
:
str
,
description
:
str
,
msa_output_dir
:
str
,
is_homomer_or_monomer
:
bool
)
->
FeatureDict
:
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str
=
f
'>chain_
{
chain_id
}
\n
{
sequence
}
\n
'
chain_msa_output_dir
=
os
.
path
.
join
(
msa_output_dir
,
chain_id
)
if
not
os
.
path
.
exists
(
chain_msa_output_dir
):
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
input_fasta_path
=
chain_fasta_path
,
alignment_dir
=
chain_msa_output_dir
)
# We only construct the pairing features if there are 2 or more unique
# sequences.
if
not
is_homomer_or_monomer
:
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_fasta_path
,
chain_msa_output_dir
)
chain_features
.
update
(
all_seq_msa_features
)
return
chain_features
def
_all_seq_msa_features
(
self
,
input_fasta_path
,
msa_output_dir
):
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path
=
os
.
path
.
join
(
msa_output_dir
,
"uniprot_hits.sto"
)
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_uniprot_accession_identifiers'
,
'msa_species_identifiers'
,
)
feats
=
{
f
'
{
k
}
_all_seq'
:
v
for
k
,
v
in
all_seq_features
.
items
()
if
k
in
valid_feats
}
return
feats
def
process
(
self
,
input_fasta_path
:
str
,
msa_output_dir
:
str
,
is_prokaryote
:
bool
=
False
)
->
FeatureDict
:
"""Runs alignment tools on the input sequences and creates features."""
with
open
(
input_fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
chain_id_map
=
_make_chain_id_map
(
sequences
=
input_seqs
,
descriptions
=
input_descs
)
chain_id_map_path
=
os
.
path
.
join
(
msa_output_dir
,
'chain_id_map.json'
)
with
open
(
chain_id_map_path
,
'w'
)
as
f
:
chain_id_map_dict
=
{
chain_id
:
dataclasses
.
asdict
(
fasta_chain
)
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
()
}
json
.
dump
(
chain_id_map_dict
,
f
,
indent
=
4
,
sort_keys
=
True
)
all_chain_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
():
if
fasta_chain
.
sequence
in
sequence_features
:
all_chain_features
[
chain_id
]
=
copy
.
deepcopy
(
sequence_features
[
fasta_chain
.
sequence
])
continue
chain_features
=
self
.
_process_single_chain
(
chain_id
=
chain_id
,
sequence
=
fasta_chain
.
sequence
,
description
=
fasta_chain
.
description
,
msa_output_dir
=
msa_output_dir
,
is_homomer_or_monomer
=
is_homomer_or_monomer
)
chain_features
=
convert_monomer_features
(
chain_features
,
chain_id
=
chain_id
)
all_chain_features
[
chain_id
]
=
chain_features
sequence_features
[
fasta_chain
.
sequence
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
is_prokaryote
=
is_prokaryote
,
)
# Pad MSA to avoid zero-sized extra_msa.
np_example
=
pad_msa
(
np_example
,
512
)
return
np_example
openfold/data/mmcif_parsing.py
View file @
57f869d6
...
...
@@ -476,6 +476,20 @@ def get_atom_coords(
pos
[
residue_constants
.
atom_order
[
"SD"
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
"SD"
]]
=
1.0
# Fix naming errors in arginine residues where NH2 is incorrectly
# assigned to be closer to CD than NH1
cd
=
residue_constants
.
atom_order
[
'CD'
]
nh1
=
residue_constants
.
atom_order
[
'NH1'
]
nh2
=
residue_constants
.
atom_order
[
'NH2'
]
if
(
res
.
get_resname
()
==
'ARG'
and
all
(
mask
[
atom_index
]
for
atom_index
in
(
cd
,
nh1
,
nh2
))
and
(
np
.
linalg
.
norm
(
pos
[
nh1
]
-
pos
[
cd
])
>
np
.
linalg
.
norm
(
pos
[
nh2
]
-
pos
[
cd
]))
):
pos
[
nh1
],
pos
[
nh2
]
=
pos
[
nh2
].
copy
(),
pos
[
nh1
].
copy
()
mask
[
nh1
],
mask
[
nh2
]
=
mask
[
nh2
].
copy
(),
mask
[
nh1
].
copy
()
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
...
...
openfold/data/templates.py
View file @
57f869d6
...
...
@@ -14,8 +14,10 @@
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import
abc
import
dataclasses
import
datetime
import
functools
import
glob
import
json
import
logging
...
...
@@ -65,10 +67,6 @@ class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class
PdbIdError
(
PrefilterError
):
"""An error indicating that the hit PDB ID was identical to the query."""
class
AlignRatioError
(
PrefilterError
):
"""An error indicating that the hit align ratio to the query was too small."""
...
...
@@ -188,7 +186,6 @@ def _assess_hhsearch_hit(
hit
:
parsers
.
TemplateHit
,
hit_pdb_code
:
str
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
datetime
.
datetime
,
max_subsequence_ratio
:
float
=
0.95
,
...
...
@@ -202,7 +199,6 @@ def _assess_hhsearch_hit(
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
...
...
@@ -214,7 +210,6 @@ def _assess_hhsearch_hit(
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
...
...
@@ -239,10 +234,6 @@ def _assess_hhsearch_hit(
f
"(
{
release_date_cutoff
}
)."
)
if
query_pdb_code
is
not
None
:
if
query_pdb_code
.
lower
()
==
hit_pdb_code
.
lower
():
raise
PdbIdError
(
"PDB code identical to Query PDB code."
)
if
align_ratio
<=
min_align_ratio
:
raise
AlignRatioError
(
"Proportion of residues aligned to query too small. "
...
...
@@ -408,9 +399,10 @@ def _realign_pdb_template_to_query(
)
try
:
(
old_aligned_template
,
new_aligned_template
),
_
=
parsers
.
parse_a3m
(
parsed_a3m
=
parsers
.
parse_a3m
(
aligner
.
align
([
old_template_sequence
,
new_template_sequence
])
)
old_aligned_template
,
new_aligned_template
=
parsed_a3m
.
sequences
except
Exception
as
e
:
raise
QueryToTemplateAlignError
(
"Could not align old template %s to template %s (%s_%s). Error: %s"
...
...
@@ -752,7 +744,6 @@ class SingleHitResult:
def
_prefilter_hit
(
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
max_template_date
:
datetime
.
datetime
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
...
...
@@ -773,7 +764,6 @@ def _prefilter_hit(
hit
=
hit
,
hit_pdb_code
=
hit_pdb_code
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
release_dates
=
release_dates
,
release_date_cutoff
=
max_template_date
,
)
...
...
@@ -781,9 +771,7 @@ def _prefilter_hit(
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
PdbIdError
,
DuplicateError
)
):
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
DuplicateError
)):
# In strict mode we treat some prefilter cases as errors.
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
...
...
@@ -792,9 +780,16 @@ def _prefilter_hit(
return
PrefilterResult
(
valid
=
True
,
error
=
None
,
warning
=
None
)
@
functools
.
lru_cache
(
16
,
typed
=
False
)
def
_read_file
(
path
):
with
open
(
path
,
'r'
)
as
f
:
file_data
=
f
.
read
()
return
file_data
def
_process_single_hit
(
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
hit
:
parsers
.
TemplateHit
,
mmcif_dir
:
str
,
max_template_date
:
datetime
.
datetime
,
...
...
@@ -832,8 +827,7 @@ def _process_single_hit(
template_sequence
,
)
# Fail if we can't find the mmCIF file.
with
open
(
cif_path
,
"r"
)
as
cif_file
:
cif_string
=
cif_file
.
read
()
cif_string
=
_read_file
(
cif_path
)
parsing_result
=
mmcif_parsing
.
parse
(
file_id
=
hit_pdb_code
,
mmcif_string
=
cif_string
...
...
@@ -866,6 +860,10 @@ def _process_single_hit(
kalign_binary_path
=
kalign_binary_path
,
_zero_center_positions
=
_zero_center_positions
,
)
if
hit
.
sum_probs
is
None
:
features
[
"template_sum_probs"
]
=
[
0
]
else
:
features
[
"template_sum_probs"
]
=
[
hit
.
sum_probs
]
# It is possible there were some errors when parsing the other chains in the
...
...
@@ -920,8 +918,8 @@ class TemplateSearchResult:
warnings
:
Sequence
[
str
]
class
TemplateHitFeaturizer
:
"""A class for turning
hhr hits to
template features."""
class
TemplateHitFeaturizer
(
abc
.
ABC
)
:
"""A
n abstract base
class for turning template
hits to
features."""
def
__init__
(
self
,
mmcif_dir
:
str
,
...
...
@@ -993,10 +991,18 @@ class TemplateHitFeaturizer:
self
.
_shuffle_top_k_prefiltered
=
_shuffle_top_k_prefiltered
self
.
_zero_center_positions
=
_zero_center_positions
@
abc
.
abstractmethod
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
class
HhsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
query_sequence
:
str
,
query_pdb_code
:
Optional
[
str
],
query_release_date
:
Optional
[
datetime
.
datetime
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
)
->
TemplateSearchResult
:
...
...
@@ -1025,7 +1031,6 @@ class TemplateHitFeaturizer:
for
hit
in
hits
:
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
max_template_date
=
template_cutoff_date
,
release_dates
=
self
.
_release_dates
,
...
...
@@ -1105,3 +1110,88 @@ class TemplateHitFeaturizer:
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
)
class
HmmsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
self
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
already_seen
=
set
()
errors
=
[]
warnings
=
[]
if
not
hits
or
hits
[
0
].
sum_probs
is
None
:
sorted_hits
=
hits
else
:
sorted_hits
=
sorted
(
hits
,
key
=
lambda
x
:
x
.
sum_probs
,
reverse
=
True
)
for
hit
in
sorted_hits
:
if
(
len
(
already_seen
)
>=
self
.
_max_hits
):
break
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
)
if
result
.
error
:
errors
.
append
(
result
.
error
)
if
result
.
warning
:
warnings
.
append
(
result
.
warning
)
if
result
.
features
is
None
:
logging
.
debug
(
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
.
name
,
result
.
error
,
result
.
warning
,
)
else
:
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
(
already_seen_key
in
already_seen
):
continue
# Increment the hit counter, since we got features out of this hit.
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
if
already_seen
:
for
name
in
template_features
:
template_features
[
name
]
=
np
.
stack
(
template_features
[
name
],
axis
=
0
).
astype
(
TEMPLATE_FEATURES
[
name
])
else
:
num_res
=
len
(
query_sequence
)
# Construct a default template with all zeros.
template_features
=
{
"template_aatype"
:
np
.
zeros
(
(
1
,
num_res
,
len
(
residue_constants
.
restypes_with_x_and_gap
)),
np
.
float32
),
"template_all_atom_masks"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
),
np
.
float32
),
"template_all_atom_positions"
:
np
.
zeros
(
(
1
,
num_res
,
residue_constants
.
atom_type_num
,
3
),
np
.
float32
),
"template_domain_names"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sequence"
:
np
.
array
([
''
.
encode
()],
dtype
=
np
.
object
),
"template_sum_probs"
:
np
.
array
([
0
],
dtype
=
np
.
float32
),
}
return
TemplateSearchResult
(
features
=
template_features
,
errors
=
errors
,
warnings
=
warnings
,
)
openfold/data/tools/hhblits.py
View file @
57f869d6
...
...
@@ -18,7 +18,7 @@ import glob
import
logging
import
os
import
subprocess
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
from
openfold.data.tools
import
utils
...
...
@@ -99,9 +99,9 @@ class HHBlits:
self
.
p
=
p
self
.
z
=
z
def
query
(
self
,
input_fasta_path
:
str
)
->
Mapping
[
str
,
Any
]:
def
query
(
self
,
input_fasta_path
:
str
)
->
List
[
Mapping
[
str
,
Any
]
]
:
"""Queries the database using HHblits."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
db_cmd
=
[]
...
...
@@ -172,4 +172,4 @@ class HHBlits:
n_iter
=
self
.
n_iter
,
e_value
=
self
.
e_value
,
)
return
raw_output
return
[
raw_output
]
openfold/data/tools/hhsearch.py
View file @
57f869d6
...
...
@@ -20,6 +20,7 @@ import os
import
subprocess
from
typing
import
Sequence
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -62,9 +63,17 @@ class HHSearch:
f
"Could not find HHsearch database
{
database_path
}
"
)
@
property
def
output_format
(
self
)
->
str
:
return
'hhr'
@
property
def
input_format
(
self
)
->
str
:
return
'a3m'
def
query
(
self
,
a3m
:
str
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
with
open
(
input_path
,
"w"
)
as
f
:
...
...
@@ -104,3 +113,11 @@ class HHSearch:
with
open
(
hhr_path
)
as
f
:
hhr
=
f
.
read
()
return
hhr
def
get_template_hits
(
self
,
output_string
:
str
,
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
"""Gets parsed template hits from the raw string output by the tool"""
del
input_sequence
# Used by hmmsearch but not needed for hhsearch
return
parsers
.
parse_hhr
(
output_string
)
openfold/data/tools/hmmbuild.py
0 → 100644
View file @
57f869d6
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import
os
import
re
import
subprocess
from
absl
import
logging
from
openfold.data.tools
import
utils
class
Hmmbuild
(
object
):
"""Python wrapper of the hmmbuild binary."""
def
__init__
(
self
,
*
,
binary_path
:
str
,
singlemx
:
bool
=
False
):
"""Initializes the Python hmmbuild wrapper.
Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.
Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self
.
binary_path
=
binary_path
self
.
singlemx
=
singlemx
def
build_profile_from_sto
(
self
,
sto
:
str
,
model_construction
=
'fast'
)
->
str
:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
return
self
.
_build_profile
(
sto
,
model_construction
=
model_construction
)
def
build_profile_from_a3m
(
self
,
a3m
:
str
)
->
str
:
"""Builds a HHM for the aligned sequences given as an A3M string.
Args:
a3m: A string with the aligned sequences in the A3M format.
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
"""
lines
=
[]
for
line
in
a3m
.
splitlines
():
if
not
line
.
startswith
(
'>'
):
line
=
re
.
sub
(
'[a-z]+'
,
''
,
line
)
# Remove inserted residues.
lines
.
append
(
line
+
'
\n
'
)
msa
=
''
.
join
(
lines
)
return
self
.
_build_profile
(
msa
,
model_construction
=
'fast'
)
def
_build_profile
(
self
,
msa
:
str
,
model_construction
:
str
=
'fast'
)
->
str
:
"""Builds a HMM for the aligned sequences given as an MSA string.
Args:
msa: A string with the aligned sequences, in A3M or STO format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').
Returns:
A string with the profile in the HMM format.
Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if
model_construction
not
in
{
'hand'
,
'fast'
}:
raise
ValueError
(
f
'Invalid model_construction
{
model_construction
}
- only'
'hand and fast supported.'
)
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_query
=
os
.
path
.
join
(
query_tmp_dir
,
'query.msa'
)
output_hmm_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.hmm'
)
with
open
(
input_query
,
'w'
)
as
f
:
f
.
write
(
msa
)
cmd
=
[
self
.
binary_path
]
# If adding flags, we have to do so before the output and input:
if
model_construction
==
'hand'
:
cmd
.
append
(
f
'--
{
model_construction
}
'
)
if
self
.
singlemx
:
cmd
.
append
(
'--singlemx'
)
cmd
.
extend
([
'--amino'
,
output_hmm_path
,
input_query
,
])
logging
.
info
(
'Launching subprocess %s'
,
cmd
)
process
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
with
utils
.
timing
(
'hmmbuild query'
):
stdout
,
stderr
=
process
.
communicate
()
retcode
=
process
.
wait
()
logging
.
info
(
'hmmbuild stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
,
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
))
if
retcode
:
raise
RuntimeError
(
'hmmbuild failed
\n
stdout:
\n
%s
\n\n
stderr:
\n
%s
\n
'
%
(
stdout
.
decode
(
'utf-8'
),
stderr
.
decode
(
'utf-8'
)))
with
open
(
output_hmm_path
,
encoding
=
'utf-8'
)
as
f
:
hmm
=
f
.
read
()
return
hmm
openfold/data/tools/hmmsearch.py
0 → 100644
View file @
57f869d6
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import
os
import
subprocess
from
typing
import
Optional
,
Sequence
from
absl
import
logging
from
openfold.data
import
parsers
from
openfold.data.tools
import
hmmbuild
from
openfold.data.tools
import
utils
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        hmmbuild_binary_path: str,
        database_path: str,
        flags: Optional[Sequence[str]] = None,
        n_cpu: int = 8,
    ):
        """Initializes the Python hmmsearch wrapper.

        Args:
            binary_path: The path to the hmmsearch executable.
            hmmbuild_binary_path: The path to the hmmbuild executable. Used to
                build an hmm from an input a3m.
            database_path: The path to the hmmsearch database (FASTA format).
            flags: List of flags to be used by hmmsearch. If None, permissive
                defaults are used (loose F1/F2/F3 filters and large E-value
                cutoffs) so that template search recalls distant hits.
            n_cpu: Number of worker threads passed to hmmsearch via --cpu.

        Raises:
            ValueError: If the hmmsearch database is not found.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(
            binary_path=hmmbuild_binary_path
        )
        self.database_path = database_path
        self.n_cpu = n_cpu
        if flags is None:
            # Default hmmsearch run settings.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(
                f'Could not find hmmsearch database {database_path}'
            )

    @property
    def output_format(self) -> str:
        # Results are returned as a Stockholm-format alignment (-A flag).
        return 'sto'

    @property
    def input_format(self) -> str:
        # Queries are accepted as Stockholm-format MSAs (see query()).
        return 'sto'

    def query(self, msa_sto: str) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        # Build a profile HMM from the MSA first; 'hand' keeps the MSA's own
        # reference-column annotation as the model construction guide.
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand'
        )
        return self.query_with_hmm(hmm)

    def query_with_hmm(self, hmm: str) -> str:
        """Queries the database using hmmsearch using a given hmm.

        Args:
            hmm: The profile HMM, as text in HMMER format.

        Returns:
            The resulting alignment as a Stockholm-format string.

        Raises:
            RuntimeError: If the hmmsearch subprocess exits with an error.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            out_path = os.path.join(query_tmp_dir, 'output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', str(self.n_cpu),
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f'hmmsearch ({os.path.basename(self.database_path)}) query'
            ):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout.decode('utf-8'), stderr.decode('utf-8'))
                )

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    def get_template_hits(
        self, output_string: str, input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        a3m_string = parsers.convert_stockholm_to_a3m(
            output_string, remove_first_row_gaps=False
        )
        template_hits = parsers.parse_hmmsearch_a3m(
            query_sequence=input_sequence,
            a3m_string=a3m_string,
            skip_first=False,
        )
        return template_hits
openfold/data/tools/jackhmmer.py
View file @
57f869d6
...
...
@@ -23,6 +23,7 @@ import subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
...
...
@@ -93,10 +94,13 @@ class Jackhmmer:
self
.
streaming_callback
=
streaming_callback
def
_query_chunk
(
self
,
input_fasta_path
:
str
,
database_path
:
str
self
,
input_fasta_path
:
str
,
database_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Mapping
[
str
,
Any
]:
"""Queries the database chunk using Jackhmmer."""
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
sto_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.sto"
)
# The F1/F2/F3 are the expected proportion to pass each of the filtering
...
...
@@ -167,8 +171,11 @@ class Jackhmmer:
with
open
(
tblout_path
)
as
f
:
tbl
=
f
.
read
()
if
(
max_sequences
is
None
):
with
open
(
sto_path
)
as
f
:
sto
=
f
.
read
()
else
:
sto
=
parsers
.
truncate_stockholm_msa
(
sto_path
,
max_sequences
)
raw_output
=
dict
(
sto
=
sto
,
...
...
@@ -180,10 +187,16 @@ class Jackhmmer:
return
raw_output
def
query
(
self
,
input_fasta_path
:
str
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
def
query
(
self
,
input_fasta_path
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
Sequence
[
Mapping
[
str
,
Any
]]:
"""Queries the database using Jackhmmer."""
if
self
.
num_streamed_chunks
is
None
:
return
[
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
)]
single_chunk_result
=
self
.
_query_chunk
(
input_fasta_path
,
self
.
database_path
,
max_sequences
,
)
return
[
single_chunk_result
]
db_basename
=
os
.
path
.
basename
(
self
.
database_path
)
db_remote_chunk
=
lambda
db_idx
:
f
"
{
self
.
database_path
}
.
{
db_idx
}
"
...
...
@@ -217,12 +230,20 @@ class Jackhmmer:
# Run Jackhmmer with the chunk
future
.
result
()
chunked_output
.
append
(
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
))
self
.
_query_chunk
(
input_fasta_path
,
db_local_chunk
(
i
),
max_sequences
)
)
# Remove the local copy of the chunk
os
.
remove
(
db_local_chunk
(
i
))
future
=
next_future
# Do not set next_future for the last chunk so that this works
# even for databases with only 1 chunk
if
(
i
<
self
.
num_streamed_chunks
):
future
=
next_future
if
self
.
streaming_callback
:
self
.
streaming_callback
(
i
)
return
chunked_output
openfold/data/tools/kalign.py
View file @
57f869d6
...
...
@@ -72,7 +72,7 @@ class Kalign:
"residues long. Got %s (%d residues)."
%
(
s
,
len
(
s
))
)
with
utils
.
tmpdir_manager
(
base_dir
=
"/tmp"
)
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_fasta_path
=
os
.
path
.
join
(
query_tmp_dir
,
"input.fasta"
)
output_a3m_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.a3m"
)
...
...
openfold/model/structure_module.py
View file @
57f869d6
...
...
@@ -16,7 +16,7 @@
import
math
import
torch
import
torch.nn
as
nn
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
,
Union
from
openfold.model.primitives
import
Linear
,
LayerNorm
,
ipa_point_weights_init_
from
openfold.np.residue_constants
import
(
...
...
@@ -151,6 +151,40 @@ class AngleResnet(nn.Module):
return
unnormalized_s
,
s
class PointProjection(nn.Module):
    """Projects per-residue activations to 3D points per attention head.

    The linear layer emits ``no_heads * num_points`` (x, y, z) triples in the
    local residue frame, which are then mapped into the global frame with the
    provided rigid transforms.
    """

    def __init__(
        self,
        c_hidden: int,
        num_points: int,
        no_heads: int,
        return_local_points: bool = False,
    ):
        """
        Args:
            c_hidden: Input channel dimension.
            num_points: Number of points generated per head.
            no_heads: Number of attention heads.
            return_local_points: If True, forward() also returns the points
                in the local (pre-transform) frame.
        """
        super().__init__()
        self.return_local_points = return_local_points
        self.no_heads = no_heads
        # Fixed: the projection must produce no_heads * num_points triples,
        # not just num_points; otherwise the per-head reshape below fails.
        self.linear = Linear(c_hidden, no_heads * 3 * num_points)

    def forward(
        self,
        activations: torch.Tensor,
        rigids: Rigid3Array,
    ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
        # TODO: Needs to run in high precision during training
        # [*, N_res, H * P * 3]
        points_local = self.linear(activations)

        # [*, N_res, H, P * 3] (fixed: shape prefix must be unpacked)
        points_local = points_local.reshape(
            *points_local.shape[:-1],
            self.no_heads,
            -1,
        )

        # Split into exactly three chunks of P values each — the x, y and z
        # components expected by Vec3Array (fixed: was a fixed split size of
        # 3, which yields P chunks instead of 3).
        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )
        points_local = Vec3Array(*points_local)

        # Rotate/translate into the global frame; the two trailing None dims
        # broadcast the per-residue rigids over heads and points.
        points_global = rigids[..., None, None].apply_to_point(points_local)

        if self.return_local_points:
            return points_global, points_local

        return points_global
class
InvariantPointAttention
(
nn
.
Module
):
"""
Implements Algorithm 22.
...
...
@@ -200,13 +234,23 @@ class InvariantPointAttention(nn.Module):
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
)
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
hpq
=
self
.
no_heads
*
self
.
no_qk_points
*
3
self
.
linear_q_points
=
Linear
(
self
.
c_s
,
hpq
)
self
.
linear_q_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_heads
)
hpkv
=
self
.
no_heads
*
(
self
.
no_qk_points
+
self
.
no_v_points
)
*
3
self
.
linear_kv_points
=
Linear
(
self
.
c_s
,
hpkv
)
self
.
linear_k_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
self
.
no_heads
,
)
hpv
=
self
.
no_heads
*
self
.
no_v_points
*
3
self
.
linear_v_points
=
PointProjection
(
self
.
c_s
,
self
.
no_v_points
self
.
no_heads
,
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
...
...
@@ -257,35 +301,14 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
# [*, N_res, H * P_q * 3]
q_pts
=
self
.
linear_q_points
(
s
)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts
=
torch
.
split
(
q_pts
,
q_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
q_pts
=
torch
.
stack
(
q_pts
,
dim
=-
1
)
q_pts
=
r
[...,
None
].
apply
(
q_pts
)
# [*, N_res, H, P_q, 3]
q_pts
=
q_pts
.
view
(
q_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
self
.
no_qk_points
,
3
)
)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts
=
self
.
linear_kv_points
(
s
)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts
=
torch
.
split
(
kv_pts
,
kv_pts
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
kv_pts
=
torch
.
stack
(
kv_pts
,
dim
=-
1
)
kv_pts
=
r
[...,
None
].
apply
(
kv_pts
)
# [*, N_res, H, P_qk]
q_pts
=
self
.
linear_q_points
(
s
,
r
)
# [*, N_res, H,
(
P_q
+ P_v)
, 3]
k
v
_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
)
)
# [*, N_res, H, P_q
k
, 3]
k_pts
=
self
.
linear_k_points
(
s
,
r
)
# [*, N_res, H, P_q/P_v, 3]
k_pts
,
v_pts
=
torch
.
split
(
kv_pts
,
[
self
.
no_qk_points
,
self
.
no_v_points
],
dim
=-
2
)
# [*, N_res, H, P_v, 3]
v_pts
=
self
.
linear_v_points
(
s
,
r
)
##########################
# Compute attention scores
...
...
@@ -302,8 +325,8 @@ class InvariantPointAttention(nn.Module):
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
.
unsqueeze
(
-
4
)
-
k_pts
.
unsqueeze
(
-
5
)
pt_att
=
pt_att
*
*
2
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
pt_att
*
pt_att
+
self
.
eps
# [*, N_res, N_res, H, P_q]
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
...
...
@@ -340,26 +363,20 @@ class InvariantPointAttention(nn.Module):
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt
=
torch
.
sum
(
(
a
[...,
None
,
:,
:,
None
]
*
permute_final_dims
(
v_pts
,
(
1
,
3
,
0
,
2
))[...,
None
,
:,
:]
),
dim
=-
2
,
# [*, N_res, H, P_v]
o_pt
=
v_pts
.
tensor_dot
(
permute_final_dims
(
a
,
(
1
,
2
,
0
)).
unsqueeze
(
-
1
)
)
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
# [*, N_res, H, P_v, 3]
o_pt
=
permute_final_dims
(
o_pt
,
(
2
,
0
,
3
,
1
))
o_pt
=
r
[...,
None
,
None
].
invert_apply
(
o_pt
)
# [*, N_res, H * P_v]
o_pt_norm
=
flatten_final_dims
(
torch
.
sqrt
(
torch
.
sum
(
o_pt
**
2
,
dim
=-
1
)
+
self
.
eps
),
2
)
# [*, N_res, H, P_v]
o_pt
=
r
[...,
None
,
None
].
apply_inverse_to_point
(
o_pt
)
# [*, N_res, H * P_v, 3]
o_pt
=
o_pt
.
reshape
(
*
o_pt
.
shape
[:
-
3
],
-
1
,
3
)
o_pt
=
o_pt
.
reshape
(
o_pt
.
shape
[:
-
2
]
+
(
-
1
,))
# [*, N_res, H * P_v]
o_pt_norm
=
o_pt
.
norm
(
self
.
eps
)
# [*, N_res, H, C_z]
o_pair
=
torch
.
matmul
(
a
.
transpose
(
-
2
,
-
3
),
z
.
to
(
dtype
=
a
.
dtype
))
...
...
@@ -370,7 +387,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, C_s]
s
=
self
.
linear_out
(
torch
.
cat
(
(
o
,
*
torch
.
unbind
(
o_pt
,
dim
=-
1
)
,
o_pt_norm
,
o_pair
),
dim
=-
1
(
o
,
*
o_pt
,
o_pt_norm
,
o_pair
),
dim
=-
1
).
to
(
dtype
=
z
.
dtype
)
)
...
...
openfold/np/residue_constants.py
View file @
57f869d6
...
...
@@ -24,8 +24,6 @@ from importlib import resources
import
numpy
as
np
import
tree
# Internal import (35fd).
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca
=
3.80209737096
...
...
@@ -1309,3 +1307,179 @@ def aatype_to_str_sequence(aatype):
restypes_with_x
[
aatype
[
i
]]
for
i
in
range
(
len
(
aatype
))
])
### ALPHAFOLD MULTIMER STUFF ###
def _make_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types
      are in the order specified in residue_constants.restypes + unknown
      residue type at the end. For chi angles which are not defined on the
      residue, the position indices are by default set to 0.
    """
    indices_per_restype = []
    for one_letter in restypes:
        three_letter = restype_1to3[one_letter]
        chi_groups = [
            [atom_order[atom_name] for atom_name in chi_atoms]
            for chi_atoms in chi_angles_atoms[three_letter]
        ]
        # Pad with all-zero rows so every residue has exactly four chi entries.
        chi_groups += [[0, 0, 0, 0]] * (4 - len(chi_groups))
        indices_per_restype.append(chi_groups)

    # The UNKNOWN residue type has no defined chi angles.
    indices_per_restype.append([[0, 0, 0, 0]] * 4)

    return np.array(indices_per_restype)
def _make_renaming_matrices():
    """Matrices to map atoms to symmetry partners in ambiguous case."""
    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative groundtruth coordinates where the naming is swapped.
    restype_3 = [restype_1to3[res] for res in restypes] + ['UNK']

    # Start every residue type with the identity; only residues listed in
    # residue_atom_renaming_swaps get a non-trivial permutation.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_atom_renaming_swaps.items():
        atom14_names = restype_name_to_atom14_names[resname]
        correspondences = np.arange(14)
        for atom_a, atom_b in swap.items():
            idx_a = atom14_names.index(atom_a)
            idx_b = atom14_names.index(atom_b)
            correspondences[idx_a] = idx_b
            correspondences[idx_b] = idx_a
        # Turn the index correspondence into a permutation matrix.
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        renaming_matrix[np.arange(14), correspondences] = 1.
        all_matrices[resname] = renaming_matrix

    return np.stack([all_matrices[restype] for restype in restype_3])
def _make_restype_atom37_mask():
    """Mask of which atoms are present for which residue type in atom37."""
    # Row 20 (UNK) deliberately stays all-zero.
    mask = np.zeros([21, 37], dtype=np.float32)
    for res_idx, one_letter in enumerate(restypes):
        for atom_name in residue_atoms[restype_1to3[one_letter]]:
            mask[res_idx, atom_order[atom_name]] = 1
    return mask
def _make_restype_atom14_mask():
    """Mask of which atoms are present for which residue type in atom14."""
    # Empty-string entries in the atom14 name lists are padding slots.
    rows = [
        [1. if name else 0.
         for name in restype_name_to_atom14_names[restype_1to3[rt]]]
        for rt in restypes
    ]
    # UNK has no atoms in the atom14 representation.
    rows.append([0.] * 14)
    return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
    """Map from atom37 to atom14 per residue type."""
    # mapping (restype, atom37) --> atom14
    rows = []
    for rt in restypes:
        idx14_by_name = {
            name: i
            for i, name in enumerate(
                restype_name_to_atom14_names[restype_1to3[rt]]
            )
        }
        # Atom37 names absent from this residue's atom14 set map to slot 0.
        rows.append([idx14_by_name.get(name, 0) for name in atom_types])
    # Dummy all-zero mapping for restype 'UNK'.
    rows.append([0] * 37)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_to_atom37():
    """Map from atom14 to atom37 per residue type."""
    # mapping (restype, atom14) --> atom37; padding slots ('') map to 0.
    rows = [
        [atom_order[name] if name else 0
         for name in restype_name_to_atom14_names[restype_1to3[rt]]]
        for rt in restypes
    ]
    # Add dummy mapping for restype 'UNK'
    rows.append([0] * 14)
    return np.array(rows, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
    """Mask which atoms are ambiguous in atom14."""
    # create an ambiguous atoms mask. shape: (21, 14)
    ambiguous_mask = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[resname]]
        atom14_names = restype_name_to_atom14_names[resname]
        for first_atom, second_atom in swap.items():
            # Both members of a swap pair count as ambiguous.
            ambiguous_mask[res_idx, atom14_names.index(first_atom)] = 1
            ambiguous_mask[res_idx, atom14_names.index(second_atom)] = 1
    return ambiguous_mask
def _make_restype_rigidgroup_base_atom37_idx():
    """Create Map from rigidgroups to atom37 indices."""
    # Atom-name table of shape
    # (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3).
    # Slots that are never assigned below keep the empty string.
    base_atom_names = np.full([21, 8, 3], '', dtype=object)

    # Group 0: backbone frame.
    base_atom_names[:, 0, :] = ['C', 'CA', 'N']

    # Group 3: 'psi-group'. (Groups 1 and 2 are left unassigned here.)
    base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # Groups 4-7: 'chi1,2,3,4-group'.
    for res_idx, one_letter in enumerate(restypes):
        resname = restype_1to3[one_letter]
        for chi_idx in range(4):
            if not chi_angles_mask[res_idx][chi_idx]:
                continue
            base_atom_names[res_idx, chi_idx + 4, :] = \
                chi_angles_atoms[resname][chi_idx][1:]

    # Translate atom names into atom37 indices; '' maps to index 0.
    lookuptable = dict(atom_order)
    lookuptable[''] = 0
    name_to_idx = np.vectorize(lambda name: lookuptable[name])
    return name_to_idx(base_atom_names)
# Precomputed lookup tables used by the multimer data/feature pipeline.
# Shapes follow the builder docstrings above: 21 residue types
# (restypes + UNK), 14/37 atom slots, 8 rigid groups, 4 chi angles.
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()

# Create mask for existing rigid groups.
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
# Group 0 (backbone) and group 3 (psi) exist for every residue type.
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
# Chi groups 4-7 exist per chi_angles_mask; row 20 (UNK) stays zero.
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment