Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
alphafold2_jax
Commits
2f0d89e7
Commit
2f0d89e7
authored
Aug 24, 2023
by
zhuwenwen
Browse files
remove duplicated code
parent
a1597f3f
Changes
103
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
6630 deletions
+0
-6630
alphafold/data/pipeline.py
alphafold/data/pipeline.py
+0
-245
alphafold/data/pipeline_multimer.py
alphafold/data/pipeline_multimer.py
+0
-284
alphafold/data/templates.py
alphafold/data/templates.py
+0
-1010
alphafold/data/tools/__init__.py
alphafold/data/tools/__init__.py
+0
-14
alphafold/data/tools/hhblits.py
alphafold/data/tools/hhblits.py
+0
-155
alphafold/data/tools/hhsearch.py
alphafold/data/tools/hhsearch.py
+0
-107
alphafold/data/tools/hmmbuild.py
alphafold/data/tools/hmmbuild.py
+0
-138
alphafold/data/tools/hmmsearch.py
alphafold/data/tools/hmmsearch.py
+0
-131
alphafold/data/tools/jackhmmer.py
alphafold/data/tools/jackhmmer.py
+0
-211
alphafold/data/tools/kalign.py
alphafold/data/tools/kalign.py
+0
-104
alphafold/data/tools/utils.py
alphafold/data/tools/utils.py
+0
-40
alphafold/model/__init__.py
alphafold/model/__init__.py
+0
-14
alphafold/model/all_atom.py
alphafold/model/all_atom.py
+0
-1141
alphafold/model/all_atom_multimer.py
alphafold/model/all_atom_multimer.py
+0
-968
alphafold/model/all_atom_test.py
alphafold/model/all_atom_test.py
+0
-135
alphafold/model/common_modules.py
alphafold/model/common_modules.py
+0
-130
alphafold/model/config.py
alphafold/model/config.py
+0
-657
alphafold/model/data.py
alphafold/model/data.py
+0
-33
alphafold/model/features.py
alphafold/model/features.py
+0
-104
alphafold/model/folding.py
alphafold/model/folding.py
+0
-1009
No files found.
alphafold/data/pipeline.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the input features for the AlphaFold model."""
import
os
# Internal import (7716).
from
typing
import
Any
,
Mapping
,
MutableMapping
,
Optional
,
Sequence
,
Union
from
absl
import
logging
from
alphafold.common
import
residue_constants
from
alphafold.data
import
msa_identifiers
from
alphafold.data
import
parsers
from
alphafold.data
import
templates
from
alphafold.data.tools
import
hhblits
from
alphafold.data.tools
import
hhsearch
from
alphafold.data.tools
import
hmmsearch
from
alphafold.data.tools
import
jackhmmer
import
numpy
as
np
# A mapping from feature name to the numpy array that holds the feature.
FeatureDict = MutableMapping[str, np.ndarray]

# Either of the two supported template-search backends.
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
def make_sequence_features(
        sequence: str, description: str, num_res: int) -> FeatureDict:
    """Builds the per-residue and bookkeeping features for one sequence.

    Args:
        sequence: Amino-acid sequence of the query chain.
        description: FASTA description line of the query chain.
        num_res: Number of residues in `sequence`.

    Returns:
        A FeatureDict with the one-hot `aatype` plus positional and metadata
        features for a single chain.
    """
    # One-hot encode the sequence; unknown residues are mapped to 'X'.
    aatype = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True)
    return {
        'aatype': aatype,
        'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
        'domain_name': np.array(
            [description.encode('utf-8')], dtype=np.object_),
        'residue_index': np.arange(num_res, dtype=np.int32),
        'seq_length': np.full(num_res, num_res, dtype=np.int32),
        'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_),
    }
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
    """Builds the MSA features, keeping only the first copy of each sequence.

    Args:
        msas: One or more parsed MSAs; the first row of the first MSA is the
            query itself.

    Returns:
        A FeatureDict holding the integer-encoded MSA, its deletion matrix
        and the species identifier of each retained row.

    Raises:
        ValueError: If no MSA, or an empty MSA, is provided.
    """
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    int_msa = []
    deletion_matrix = []
    species_ids = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f'MSA {msa_index} must contain at least one sequence.')
        for sequence_index, sequence in enumerate(msa.sequences):
            # Deduplicate across all MSAs: only the first occurrence is kept.
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
            deletion_matrix.append(msa.deletion_matrix[sequence_index])
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[sequence_index])
            species_ids.append(identifiers.species_id.encode('utf-8'))

    num_res = len(msas[0].sequences[0])
    num_alignments = len(int_msa)
    return {
        'deletion_matrix_int': np.array(deletion_matrix, dtype=np.int32),
        'msa': np.array(int_msa, dtype=np.int32),
        'num_alignments': np.array([num_alignments] * num_res, dtype=np.int32),
        'msa_species_identifiers': np.array(species_ids, dtype=np.object_),
    }
def run_msa_tool(msa_runner,
                 input_fasta_path: str,
                 msa_out_path: str,
                 msa_format: str,
                 use_precomputed_msas: bool,
                 max_sto_sequences: Optional[int] = None
                 ) -> Mapping[str, Any]:
    """Runs an MSA tool, reusing an on-disk result when permitted.

    Args:
        msa_runner: A tool wrapper exposing a `query` method.
        input_fasta_path: FASTA file with the query sequence.
        msa_out_path: Where the raw MSA output is written / looked up.
        msa_format: Output format key, e.g. 'sto' or 'a3m'.
        use_precomputed_msas: If True and `msa_out_path` exists, read it back
            instead of re-running the tool.
        max_sto_sequences: Optional row cap, honored for 'sto' output only.

    Returns:
        A mapping from `msa_format` to the MSA text.
    """
    # De Morgan of the cache condition: recompute unless a reusable file exists.
    cached = use_precomputed_msas and os.path.exists(msa_out_path)
    truncate_sto = msa_format == 'sto' and max_sto_sequences is not None
    if not cached:
        if truncate_sto:
            result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]  # pytype: disable=wrong-arg-count
        else:
            result = msa_runner.query(input_fasta_path)[0]
        with open(msa_out_path, 'w') as f:
            f.write(result[msa_format])
    else:
        logging.warning('Reading MSA from file %s', msa_out_path)
        if truncate_sto:
            # Apply the sequence cap to the cached Stockholm file too.
            precomputed_msa = parsers.truncate_stockholm_msa(
                msa_out_path, max_sto_sequences)
            result = {'sto': precomputed_msa}
        else:
            with open(msa_out_path, 'r') as f:
                result = {msa_format: f.read()}
    return result
class DataPipeline:
    """Runs the alignment tools and assembles the input features.

    Wraps the external search tools (jackhmmer, hhblits, the template
    searcher) and turns their raw outputs into the model's FeatureDict.
    """

    def __init__(self,
                 jackhmmer_binary_path: str,
                 hhblits_binary_path: str,
                 uniref90_database_path: str,
                 mgnify_database_path: str,
                 bfd_database_path: Optional[str],
                 uniclust30_database_path: Optional[str],
                 small_bfd_database_path: Optional[str],
                 template_searcher: TemplateSearcher,
                 template_featurizer: templates.TemplateHitFeaturizer,
                 use_small_bfd: bool,
                 mgnify_max_hits: int = 501,
                 uniref_max_hits: int = 10000,
                 use_precomputed_msas: bool = False):
        """Initializes the data pipeline."""
        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path)
        # The deep BFD search uses either jackhmmer (small BFD) or hhblits.
        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path)
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path])
        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path)
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer
        self.mgnify_max_hits = mgnify_max_hits
        self.uniref_max_hits = uniref_max_hits
        self.use_precomputed_msas = use_precomputed_msas

    def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
        """Runs the alignment tools on one sequence and builds its features.

        Args:
            input_fasta_path: FASTA file holding exactly one sequence.
            msa_output_dir: Directory to write raw MSA/template hits to.

        Returns:
            The merged sequence, MSA and template FeatureDict.

        Raises:
            ValueError: If the FASTA file holds more than one sequence, or the
                template searcher advertises an unrecognized input format.
        """
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f'More than one input sequence found in {input_fasta_path}.')
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        # Shallow searches against UniRef90 and MGnify (cached on disk when
        # use_precomputed_msas is set).
        uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
        jackhmmer_uniref90_result = run_msa_tool(
            msa_runner=self.jackhmmer_uniref90_runner,
            input_fasta_path=input_fasta_path,
            msa_out_path=uniref90_out_path,
            msa_format='sto',
            use_precomputed_msas=self.use_precomputed_msas,
            max_sto_sequences=self.uniref_max_hits)
        mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
        jackhmmer_mgnify_result = run_msa_tool(
            msa_runner=self.jackhmmer_mgnify_runner,
            input_fasta_path=input_fasta_path,
            msa_out_path=mgnify_out_path,
            msa_format='sto',
            use_precomputed_msas=self.use_precomputed_msas,
            max_sto_sequences=self.mgnify_max_hits)

        # The template search runs on a cleaned-up version of the UniRef90 MSA.
        msa_for_templates = jackhmmer_uniref90_result['sto']
        msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
            msa_for_templates)

        if self.template_searcher.input_format == 'sto':
            pdb_templates_result = self.template_searcher.query(msa_for_templates)
        elif self.template_searcher.input_format == 'a3m':
            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
            pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
        else:
            raise ValueError('Unrecognized template input format: '
                             f'{self.template_searcher.input_format}')

        pdb_hits_out_path = os.path.join(
            msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
        with open(pdb_hits_out_path, 'w') as f:
            f.write(pdb_templates_result)

        uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
        mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])

        pdb_template_hits = self.template_searcher.get_template_hits(
            output_string=pdb_templates_result, input_sequence=input_sequence)

        # Deep search: small-BFD via jackhmmer, or full BFD+Uniclust via hhblits.
        if self._use_small_bfd:
            bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
            jackhmmer_small_bfd_result = run_msa_tool(
                msa_runner=self.jackhmmer_small_bfd_runner,
                input_fasta_path=input_fasta_path,
                msa_out_path=bfd_out_path,
                msa_format='sto',
                use_precomputed_msas=self.use_precomputed_msas)
            bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
        else:
            bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
            hhblits_bfd_uniclust_result = run_msa_tool(
                msa_runner=self.hhblits_bfd_uniclust_runner,
                input_fasta_path=input_fasta_path,
                msa_out_path=bfd_out_path,
                msa_format='a3m',
                use_precomputed_msas=self.use_precomputed_msas)
            bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence, hits=pdb_template_hits)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res)
        msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))

        logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
        logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
        logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
        logging.info('Final (deduplicated) MSA size: %d sequences.',
                     msa_features['num_alignments'][0])
        logging.info('Total number of templates (NB: this can include bad '
                     'templates and is later filtered to top 4): %d.',
                     templates_result.features['template_domain_names'].shape[0])

        return {**sequence_features, **msa_features, **templates_result.features}
alphafold/data/pipeline_multimer.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the features for the AlphaFold multimer model."""
import
collections
import
contextlib
import
copy
import
dataclasses
import
json
import
os
import
tempfile
from
typing
import
Mapping
,
MutableMapping
,
Sequence
from
absl
import
logging
from
alphafold.common
import
protein
from
alphafold.common
import
residue_constants
from
alphafold.data
import
feature_processing
from
alphafold.data
import
msa_pairing
from
alphafold.data
import
parsers
from
alphafold.data
import
pipeline
from
alphafold.data.tools
import
jackhmmer
import
numpy
as
np
# Internal import (7716).
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
_FastaChain
:
sequence
:
str
description
:
str
def _make_chain_id_map(*,
                       sequences: Sequence[str],
                       descriptions: Sequence[str],
                       ) -> Mapping[str, _FastaChain]:
    """Maps PDB-format chain IDs to their sequence and description.

    Args:
        sequences: One amino-acid sequence per chain.
        descriptions: One FASTA description per chain, parallel to `sequences`.

    Returns:
        A mapping from single-letter PDB chain ID to its `_FastaChain`.

    Raises:
        ValueError: If the inputs differ in length or there are more chains
            than the PDB format can represent.
    """
    if len(sequences) != len(descriptions):
        raise ValueError('sequences and descriptions must have equal length. '
                         f'Got {len(sequences)} != {len(descriptions)}.')
    if len(sequences) > protein.PDB_MAX_CHAINS:
        raise ValueError('Cannot process more chains than the PDB format supports. '
                         f'Got {len(sequences)} chains.')
    # Chain IDs are assigned in PDB order: A, B, C, ...
    return {
        chain_id: _FastaChain(sequence=seq, description=desc)
        for chain_id, seq, desc in zip(
            protein.PDB_CHAIN_IDS, sequences, descriptions)
    }
@contextlib.contextmanager
def temp_fasta_file(fasta_str: str):
    """Yields the path of a temporary FASTA file containing `fasta_str`.

    The file is deleted when the context exits.
    """
    with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file:
        fasta_file.write(fasta_str)
        # Flush to the start so readers of the path see the full contents.
        fasta_file.seek(0)
        yield fasta_file.name
def convert_monomer_features(monomer_features: pipeline.FeatureDict,
                             chain_id: str) -> pipeline.FeatureDict:
    """Reshapes and renames monomer features for the multimer model.

    Args:
        monomer_features: Feature dict produced by the monomer pipeline.
        chain_id: Author chain ID to record for this chain.

    Returns:
        A new feature dict with leading singleton dims dropped, one-hot
        aatypes converted to indices, and masks renamed.
    """
    converted = {}
    converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
    # These carry a spurious leading dimension of size 1 from the monomer code.
    unnecessary_leading_dim_feats = {
        'sequence', 'domain_name', 'num_alignments', 'seq_length'}
    for feature_name, feature in monomer_features.items():
        if feature_name in unnecessary_leading_dim_feats:
            # asarray ensures it's a np.ndarray.
            feature = np.asarray(feature[0], dtype=feature.dtype)
        elif feature_name == 'aatype':
            # The multimer model performs the one-hot operation itself.
            feature = np.argmax(feature, axis=-1).astype(np.int32)
        elif feature_name == 'template_aatype':
            feature = np.argmax(feature, axis=-1).astype(np.int32)
            # Re-index from the HHblits amino-acid ordering to ours.
            new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
            feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
        elif feature_name == 'template_all_atom_masks':
            feature_name = 'template_all_atom_mask'
        converted[feature_name] = feature
    return converted
def int_id_to_str_id(num: int) -> str:
    """Encodes a positive integer in reverse spreadsheet style.

    1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the usual way
    chain IDs are encoded in mmCIF files.

    Args:
        num: A positive integer.

    Returns:
        The string encoding of `num`.

    Raises:
        ValueError: If `num` is not positive.
    """
    if num <= 0:
        raise ValueError(f'Only positive integers allowed, got {num}.')

    num -= 1  # Switch to 0-based indexing.
    digits = []
    while num >= 0:
        digits.append(chr(num % 26 + ord('A')))
        num = num // 26 - 1
    return ''.join(digits)
def add_assembly_features(
        all_chain_features: MutableMapping[str, pipeline.FeatureDict],
) -> MutableMapping[str, pipeline.FeatureDict]:
    """Adds per-chain assembly identifiers (asym/sym/entity ids).

    Args:
        all_chain_features: Maps chain_id to that chain's feature dict.

    Returns:
        A dict keyed by `<entity_letter>_<sym_id>` (e.g. a homodimer yields
        A_1 and A_2; a heterodimer yields A_1 and B_1), with `asym_id`,
        `sym_id` and `entity_id` features added to each chain in place.
    """
    # Group chains that share a sequence under one entity id.
    seq_to_entity_id = {}
    grouped_chains = collections.defaultdict(list)
    for chain_id, chain_features in all_chain_features.items():
        seq = str(chain_features['sequence'])
        if seq not in seq_to_entity_id:
            seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
        grouped_chains[seq_to_entity_id[seq]].append(chain_features)

    new_all_chain_features = {}
    chain_id = 1
    for entity_id, group_chain_features in grouped_chains.items():
        for sym_id, chain_features in enumerate(group_chain_features, start=1):
            new_all_chain_features[
                f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features
            seq_length = chain_features['seq_length']
            # Broadcast the scalar ids over every residue of the chain.
            chain_features['asym_id'] = chain_id * np.ones(seq_length)
            chain_features['sym_id'] = sym_id * np.ones(seq_length)
            chain_features['entity_id'] = entity_id * np.ones(seq_length)
            chain_id += 1
    return new_all_chain_features
def pad_msa(np_example, min_num_seq):
    """Pads the MSA features up to `min_num_seq` rows with zeros.

    Ensures the extra_msa stack is never zero-sized downstream. Returns a
    shallow copy; inputs already at least `min_num_seq` deep are unchanged.
    """
    np_example = dict(np_example)
    num_seq = np_example['msa'].shape[0]
    if num_seq < min_num_seq:
        pad_rows = min_num_seq - num_seq
        # 2-D MSA features: pad rows only, never columns.
        for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'):
            np_example[feat] = np.pad(
                np_example[feat], ((0, pad_rows), (0, 0)))
        np_example['cluster_bias_mask'] = np.pad(
            np_example['cluster_bias_mask'], ((0, pad_rows),))
    return np_example
class DataPipeline:
    """Runs the alignment tools and assembles the multimer input features."""

    def __init__(self,
                 monomer_data_pipeline: pipeline.DataPipeline,
                 jackhmmer_binary_path: str,
                 uniprot_database_path: str,
                 max_uniprot_hits: int = 50000,
                 use_precomputed_msas: bool = False):
        """Initializes the data pipeline.

        Args:
            monomer_data_pipeline: An instance of pipeline.DataPipeline - that
                runs the data pipeline for the monomer AlphaFold system.
            jackhmmer_binary_path: Location of the jackhmmer binary.
            uniprot_database_path: Location of the unclustered uniprot
                sequences, that will be searched with jackhmmer and used for
                MSA pairing.
            max_uniprot_hits: The maximum number of hits to return from
                uniprot.
            use_precomputed_msas: Whether to use pre-existing MSAs; see
                run_alphafold.
        """
        self._monomer_data_pipeline = monomer_data_pipeline
        self._uniprot_msa_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniprot_database_path)
        self._max_uniprot_hits = max_uniprot_hits
        self.use_precomputed_msas = use_precomputed_msas

    def _process_single_chain(
            self,
            chain_id: str,
            sequence: str,
            description: str,
            msa_output_dir: str,
            is_homomer_or_monomer: bool) -> pipeline.FeatureDict:
        """Runs the monomer pipeline on a single chain."""
        chain_fasta_str = f'>chain_{chain_id}\n{sequence}\n'
        chain_msa_output_dir = os.path.join(msa_output_dir, chain_id)
        if not os.path.exists(chain_msa_output_dir):
            os.makedirs(chain_msa_output_dir)
        with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
            logging.info('Running monomer pipeline on chain %s: %s',
                         chain_id, description)
            chain_features = self._monomer_data_pipeline.process(
                input_fasta_path=chain_fasta_path,
                msa_output_dir=chain_msa_output_dir)

            # Pairing features are only needed when there are 2+ unique
            # sequences in the complex.
            if not is_homomer_or_monomer:
                all_seq_msa_features = self._all_seq_msa_features(
                    chain_fasta_path, chain_msa_output_dir)
                chain_features.update(all_seq_msa_features)
        return chain_features

    def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
        """Gets MSA features over unclustered uniprot, for MSA pairing."""
        out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
        result = pipeline.run_msa_tool(
            self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
            self.use_precomputed_msas)
        msa = parsers.parse_stockholm(result['sto'])
        msa = msa.truncate(max_seqs=self._max_uniprot_hits)
        all_seq_features = pipeline.make_msa_features([msa])
        # Keep only the features that MSA pairing consumes.
        valid_feats = msa_pairing.MSA_FEATURES + ('msa_species_identifiers',)
        return {f'{k}_all_seq': v for k, v in all_seq_features.items()
                if k in valid_feats}

    def process(self,
                input_fasta_path: str,
                msa_output_dir: str) -> pipeline.FeatureDict:
        """Runs alignment tools on all chains and assembles paired features."""
        with open(input_fasta_path) as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        chain_id_map = _make_chain_id_map(sequences=input_seqs,
                                          descriptions=input_descs)
        # Persist the chain-id assignment for later inspection.
        chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json')
        with open(chain_id_map_path, 'w') as f:
            chain_id_map_dict = {
                chain_id: dataclasses.asdict(fasta_chain)
                for chain_id, fasta_chain in chain_id_map.items()}
            json.dump(chain_id_map_dict, f, indent=4, sort_keys=True)

        all_chain_features = {}
        sequence_features = {}
        is_homomer_or_monomer = len(set(input_seqs)) == 1
        for chain_id, fasta_chain in chain_id_map.items():
            # Identical chains reuse the features computed for the first copy.
            if fasta_chain.sequence in sequence_features:
                all_chain_features[chain_id] = copy.deepcopy(
                    sequence_features[fasta_chain.sequence])
                continue
            chain_features = self._process_single_chain(
                chain_id=chain_id,
                sequence=fasta_chain.sequence,
                description=fasta_chain.description,
                msa_output_dir=msa_output_dir,
                is_homomer_or_monomer=is_homomer_or_monomer)
            chain_features = convert_monomer_features(chain_features,
                                                      chain_id=chain_id)
            all_chain_features[chain_id] = chain_features
            sequence_features[fasta_chain.sequence] = chain_features

        all_chain_features = add_assembly_features(all_chain_features)

        np_example = feature_processing.pair_and_merge(
            all_chain_features=all_chain_features)

        # Pad MSA to avoid zero-sized extra_msa.
        np_example = pad_msa(np_example, 512)
        return np_example
alphafold/data/templates.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import
abc
import
dataclasses
import
datetime
import
functools
import
glob
import
os
import
re
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
from
alphafold.common
import
residue_constants
from
alphafold.data
import
mmcif_parsing
from
alphafold.data
import
parsers
from
alphafold.data.tools
import
kalign
import
numpy
as
np
# Internal import (7716).
class Error(Exception):
    """Base class for exceptions raised by the template machinery."""


class NoChainsError(Error):
    """Raised when a template mmCIF contained no chains."""


class SequenceNotInTemplateError(Error):
    """Raised when a template mmCIF did not contain the expected sequence."""


class NoAtomDataInTemplateError(Error):
    """Raised when a template mmCIF contained no atom positions."""


class TemplateAtomMaskAllZerosError(Error):
    """Raised when every atom position in a template mmCIF was masked."""


class QueryToTemplateAlignError(Error):
    """Raised when the query cannot be aligned to the template."""


class CaDistanceError(Error):
    """Raised when a CA atom distance exceeds a threshold."""


class MultipleChainsError(Error):
    """Raised when multiple chains were found for a single chain ID."""


# Prefilter exceptions.
class PrefilterError(Exception):
    """Base class for template prefilter exceptions."""


class DateError(PrefilterError):
    """Raised when the hit date was after the maximum allowed date."""


class AlignRatioError(PrefilterError):
    """Raised when the hit's align ratio to the query was too small."""


class DuplicateError(PrefilterError):
    """Raised when the hit was an exact subsequence of the query."""


class LengthError(PrefilterError):
    """Raised when the hit was too short."""
# Dtypes of each template feature produced by the featurizer.
# NOTE: the original used `np.object`, a deprecated alias of the builtin
# `object` that was removed in NumPy 1.24 — it raised AttributeError at import
# time on modern NumPy. The builtin `object` is the drop-in replacement.
TEMPLATE_FEATURES = {
    'template_aatype': np.float32,
    'template_all_atom_masks': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': object,
    'template_sequence': object,
    'template_sum_probs': np.float32,
}
def
_get_pdb_id_and_chain
(
hit
:
parsers
.
TemplateHit
)
->
Tuple
[
str
,
str
]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match
=
re
.
match
(
r
'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+'
,
hit
.
name
)
if
not
id_match
:
raise
ValueError
(
f
'hit.name did not start with PDBID_chain:
{
hit
.
name
}
'
)
pdb_id
,
chain_id
=
id_match
.
group
(
0
).
split
(
'_'
)
return
pdb_id
.
lower
(),
chain_id
def
_is_after_cutoff
(
pdb_id
:
str
,
release_dates
:
Mapping
[
str
,
datetime
.
datetime
],
release_date_cutoff
:
Optional
[
datetime
.
datetime
])
->
bool
:
"""Checks if the template date is after the release date cutoff.
Args:
pdb_id: 4 letter pdb code.
release_dates: Dictionary mapping PDB ids to their structure release dates.
release_date_cutoff: Max release date that is valid for this query.
Returns:
True if the template release date is after the cutoff, False otherwise.
"""
if
release_date_cutoff
is
None
:
raise
ValueError
(
'The release_date_cutoff must not be None.'
)
if
pdb_id
in
release_dates
:
return
release_dates
[
pdb_id
]
>
release_date_cutoff
else
:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
return
False
def
_parse_obsolete
(
obsolete_file_path
:
str
)
->
Mapping
[
str
,
Optional
[
str
]]:
"""Parses the data file from PDB that lists which pdb_ids are obsolete."""
with
open
(
obsolete_file_path
)
as
f
:
result
=
{}
for
line
in
f
:
line
=
line
.
strip
()
# Format: Date From To
# 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare
# 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common
# 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare
if
line
.
startswith
(
'OBSLTE'
):
if
len
(
line
)
>
30
:
# Replaced by at least one structure.
from_id
=
line
[
20
:
24
].
lower
()
to_id
=
line
[
29
:
33
].
lower
()
result
[
from_id
]
=
to_id
elif
len
(
line
)
==
24
:
# Removed.
from_id
=
line
[
20
:
24
].
lower
()
result
[
from_id
]
=
None
return
result
def
_parse_release_dates
(
path
:
str
)
->
Mapping
[
str
,
datetime
.
datetime
]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
if
path
.
endswith
(
'txt'
):
release_dates
=
{}
with
open
(
path
,
'r'
)
as
f
:
for
line
in
f
:
pdb_id
,
date
=
line
.
split
(
':'
)
date
=
date
.
strip
()
# Python 3.6 doesn't have datetime.date.fromisoformat() which is about
# 90x faster than strptime. However, splitting the string manually is
# about 10x faster than strptime.
release_dates
[
pdb_id
.
strip
()]
=
datetime
.
datetime
(
year
=
int
(
date
[:
4
]),
month
=
int
(
date
[
5
:
7
]),
day
=
int
(
date
[
8
:
10
]))
return
release_dates
else
:
raise
ValueError
(
'Invalid format of the release date file %s.'
%
path
)
def _assess_hhsearch_hit(
        hit: parsers.TemplateHit,
        hit_pdb_code: str,
        query_sequence: str,
        release_dates: Mapping[str, datetime.datetime],
        release_date_cutoff: datetime.datetime,
        max_subsequence_ratio: float = 0.95,
        min_align_ratio: float = 0.1) -> bool:
    """Prefilters a template hit without parsing its mmCIF file.

    Args:
        hit: HhrHit for the template.
        hit_pdb_code: The 4 letter pdb code of the template hit. This might
            differ from the value in the hit if the original pdb became
            obsolete.
        query_sequence: Amino acid sequence of the query.
        release_dates: Maps pdb codes to their structure release dates.
        release_date_cutoff: Max release date that is valid for this query.
        max_subsequence_ratio: Exclude exact matches with this much overlap.
        min_align_ratio: Minimum overlap between the template and query.

    Returns:
        True if the hit passed the prefilter. Raises an exception otherwise.

    Raises:
        DateError: If the hit date was after the max allowed date.
        AlignRatioError: If the hit align ratio to the query was too small.
        DuplicateError: If the hit was an exact subsequence of the query.
        LengthError: If the hit was too short.
    """
    aligned_cols = hit.aligned_cols
    align_ratio = aligned_cols / len(query_sequence)

    template_sequence = hit.hit_sequence.replace('-', '')
    length_ratio = float(len(template_sequence)) / len(query_sequence)

    # A near-total exact subsequence is almost certainly the query itself
    # appearing in PDB as a duplicate entry.
    duplicate = (template_sequence in query_sequence and
                 length_ratio > max_subsequence_ratio)

    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        raise DateError(
            f'Date ({release_dates[hit_pdb_code]}) > max template date '
            f'({release_date_cutoff}).')

    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            'Proportion of residues aligned to query too small. '
            f'Align ratio: {align_ratio}.')

    if duplicate:
        raise DuplicateError(
            'Template is an exact subsequence of query with large '
            f'coverage. Length ratio: {length_ratio}.')

    if len(template_sequence) < 10:
        raise LengthError(
            f'Template too short. Length: {len(template_sequence)}.')

    return True
def _find_template_in_pdb(
    template_chain_id: str,
    template_sequence: str,
    mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]:
  """Locates the template sequence inside the given mmCIF object.

  Three strategies are attempted, in order of decreasing strictness:
    1. Exact subsequence match in the chain with the requested chain ID.
    2. Exact subsequence match in any chain.
    3. Fuzzy match in any chain, where 'X' acts as a wildcard on both sides.

  Args:
    template_chain_id: The template chain ID.
    template_sequence: The template chain sequence.
    mmcif_object: The PDB object to search for the template in.

  Returns:
    A (chain_sequence, chain_id, offset) tuple, where offset is the index at
    which the template sequence starts inside chain_sequence.

  Raises:
    SequenceNotInTemplateError: If all three strategies fail.
  """
  pdb_id = mmcif_object.file_id

  # Strategy 1: exact match in both the chain ID and the (sub)sequence.
  requested_chain_seq = mmcif_object.chain_to_seqres.get(template_chain_id)
  if requested_chain_seq and template_sequence in requested_chain_seq:
    logging.info('Found an exact template match %s_%s.',
                 pdb_id, template_chain_id)
    return (requested_chain_seq, template_chain_id,
            requested_chain_seq.find(template_sequence))

  # Strategy 2: exact match in the (sub)sequence of any chain.
  for candidate_id, candidate_seq in mmcif_object.chain_to_seqres.items():
    if candidate_seq and template_sequence in candidate_seq:
      logging.info('Found a sequence-only match %s_%s.', pdb_id, candidate_id)
      return candidate_seq, candidate_id, candidate_seq.find(template_sequence)

  # Strategy 3: fuzzy match treating 'X' as a wildcard in either sequence.
  # Non-capturing groups (?:_) sidestep re's limit of 100 named groups.
  pattern = re.compile(''.join(
      '.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence))
  for candidate_id, candidate_seq in mmcif_object.chain_to_seqres.items():
    fuzzy = pattern.search(candidate_seq)
    if fuzzy:
      logging.info('Found a fuzzy sequence-only match %s_%s.',
                   pdb_id, candidate_id)
      return candidate_seq, candidate_id, fuzzy.start()

  # No hits, raise an error.
  raise SequenceNotInTemplateError(
      'Could not find the template sequence in %s_%s. Template sequence: %s, '
      'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
                               mmcif_object.chain_to_seqres))
def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]:
  """Aligns template from the mmcif_object to the query.

  In case PDB70 contains a different version of the template sequence, we need
  to perform a realignment to the actual sequence that is in the mmCIF file.
  This method performs such realignment, but returns the new sequence and
  mapping only if the sequence in the mmCIF file is 90% identical to the old
  sequence.

  Note that the old_template_sequence comes from the hit, and contains only that
  part of the chain that matches with the query while the new_template_sequence
  is the full chain.

  Args:
    old_template_sequence: The template sequence that was returned by the PDB
      template search (typically done using HHSearch).
    template_chain_id: The template chain id was returned by the PDB template
      search (typically done using HHSearch). This is used to find the right
      chain in the mmcif_object chain_to_seqres mapping.
    mmcif_object: A mmcif_object which holds the actual template data.
    old_mapping: A mapping from the query sequence to the template sequence.
      This mapping will be used to compute the new mapping from the query
      sequence to the actual mmcif_object template sequence by aligning the
      old_template_sequence and the actual template sequence.
    kalign_binary_path: The path to a kalign executable.

  Returns:
    A tuple (new_template_sequence, new_query_to_template_mapping) where:
    * new_template_sequence is the actual template sequence that was found in
      the mmcif_object.
    * new_query_to_template_mapping is the new mapping from the query to the
      actual template found in the mmcif_object.

  Raises:
    QueryToTemplateAlignError:
    * If there was an error thrown by the alignment tool.
    * Or if the actual template sequence differs by more than 10% from the
      old_template_sequence.
  """
  aligner = kalign.Kalign(binary_path=kalign_binary_path)
  new_template_sequence = mmcif_object.chain_to_seqres.get(
      template_chain_id, '')

  # Sometimes the template chain id is unknown. But if there is only a single
  # sequence within the mmcif_object, it is safe to assume it is that one.
  if not new_template_sequence:
    if len(mmcif_object.chain_to_seqres) == 1:
      logging.info('Could not find %s in %s, but there is only 1 sequence, so '
                   'using that one.',
                   template_chain_id,
                   mmcif_object.file_id)
      new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0]
    else:
      raise QueryToTemplateAlignError(
          f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
          'If there are no mmCIF parsing errors, it is possible it was not a '
          'protein chain.')

  # Pairwise-align the hit-derived sequence against the full mmCIF chain.
  try:
    parsed_a3m = parsers.parse_a3m(
        aligner.align([old_template_sequence, new_template_sequence]))
    old_aligned_template, new_aligned_template = parsed_a3m.sequences
  except Exception as e:
    raise QueryToTemplateAlignError(
        'Could not align old template %s to template %s (%s_%s). Error: %s' %
        (old_template_sequence, new_template_sequence, mmcif_object.file_id,
         template_chain_id, str(e)))

  logging.info('Old aligned template: %s\nNew aligned template: %s',
               old_aligned_template, new_aligned_template)

  # Walk the two aligned strings in lockstep, counting non-gap positions in
  # each to recover old-index -> new-index pairs, and tallying identical
  # residues for the similarity check below.
  old_to_new_template_mapping = {}
  old_template_index = -1
  new_template_index = -1
  num_same = 0
  for old_template_aa, new_template_aa in zip(
      old_aligned_template, new_aligned_template):
    if old_template_aa != '-':
      old_template_index += 1
    if new_template_aa != '-':
      new_template_index += 1
    if old_template_aa != '-' and new_template_aa != '-':
      old_to_new_template_mapping[old_template_index] = new_template_index
      if old_template_aa == new_template_aa:
        num_same += 1

  # Require at least 90 % sequence identity wrt to the shorter of the sequences.
  if float(num_same) / min(
      len(old_template_sequence), len(new_template_sequence)) < 0.9:
    raise QueryToTemplateAlignError(
        'Insufficient similarity of the sequence in the database: %s to the '
        'actual sequence in the mmCIF file %s_%s: %s. We require at least '
        '90 %% similarity wrt to the shorter of the sequences. This is not a '
        'problem unless you think this is a template that should be included.' %
        (old_template_sequence, mmcif_object.file_id, template_chain_id,
         new_template_sequence))

  # Compose the query->old mapping with the old->new mapping; query positions
  # whose old template residue has no counterpart in the new chain map to -1.
  new_query_to_template_mapping = {}
  for query_index, old_template_index in old_mapping.items():
    new_query_to_template_mapping[query_index] = (
        old_to_new_template_mapping.get(old_template_index, -1))

  new_template_sequence = new_template_sequence.replace('-', '')

  return new_template_sequence, new_query_to_template_mapping
def _check_residue_distances(all_positions: np.ndarray,
                             all_positions_mask: np.ndarray,
                             max_ca_ca_distance: float):
  """Raises CaDistanceError when consecutive unmasked residues are too far.

  Only residue pairs that are adjacent in the sequence AND both have an
  unmasked C-alpha are compared; a masked residue breaks the chain, so the
  residue after it is never compared against anything before it.
  """
  ca_index = residue_constants.atom_order['CA']
  prev_is_unmasked = False
  prev_calpha = None
  for index, (coords, mask) in enumerate(
      zip(all_positions, all_positions_mask)):
    if not mask[ca_index]:
      # No C-alpha here: the following residue has no predecessor to check.
      prev_is_unmasked = False
      continue
    calpha = coords[ca_index]
    if prev_is_unmasked:
      gap = np.linalg.norm(calpha - prev_calpha)
      if gap > max_ca_ca_distance:
        raise CaDistanceError(
            'The distance between residues %d and %d is %f > limit %f.' % (
                index, index + 1, gap, max_ca_ca_distance))
    prev_calpha = calpha
    prev_is_unmasked = True
def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
  """Gets atom positions and mask from a list of Biopython Residues.

  Args:
    mmcif_object: The parsed mmCIF file holding the structure.
    auth_chain_id: Author chain ID to extract; must match exactly one chain in
      the structure.
    max_ca_ca_distance: Maximum allowed C-alpha - C-alpha distance between
      consecutive residues (forwarded to _check_residue_distances).

  Returns:
    A tuple (all_positions, all_positions_mask): a [num_res, atom_type_num, 3]
    coordinate array and a [num_res, atom_type_num] int64 presence mask,
    indexed by SEQRES position (missing residues stay all-zero).

  Raises:
    MultipleChainsError: If the structure does not contain exactly one chain
      with the given ID.
    CaDistanceError: Via _check_residue_distances, if consecutive residues are
      further apart than max_ca_ca_distance.
  """
  num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
  relevant_chains = [c for c in mmcif_object.structure.get_chains()
                     if c.id == auth_chain_id]
  if len(relevant_chains) != 1:
    raise MultipleChainsError(
        f'Expected exactly one chain in structure with id {auth_chain_id}.')
  chain = relevant_chains[0]
  all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
  all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
                                dtype=np.int64)
  for res_index in range(num_res):
    pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
    mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
    # Map the SEQRES position to the corresponding structure residue; it may
    # be flagged missing when the residue was not resolved experimentally.
    res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
    if not res_at_position.is_missing:
      # Biopython residue lookup key: (hetflag, resseq number, insertion code).
      res = chain[(res_at_position.hetflag,
                   res_at_position.position.residue_number,
                   res_at_position.position.insertion_code)]
      for atom in res.get_atoms():
        atom_name = atom.get_name()
        x, y, z = atom.get_coord()
        if atom_name in residue_constants.atom_order.keys():
          pos[residue_constants.atom_order[atom_name]] = [x, y, z]
          mask[residue_constants.atom_order[atom_name]] = 1.0
        elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
          # Put the coordinates of the selenium atom in the sulphur column.
          pos[residue_constants.atom_order['SD']] = [x, y, z]
          mask[residue_constants.atom_order['SD']] = 1.0

      # Fix naming errors in arginine residues where NH2 is incorrectly
      # assigned to be closer to CD than NH1.
      cd = residue_constants.atom_order['CD']
      nh1 = residue_constants.atom_order['NH1']
      nh2 = residue_constants.atom_order['NH2']
      if (res.get_resname() == 'ARG' and
          all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
          (np.linalg.norm(pos[nh1] - pos[cd]) >
           np.linalg.norm(pos[nh2] - pos[cd]))):
        # Swap both coordinates and masks so NH1 is always the closer atom.
        pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
        mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()

    all_positions[res_index] = pos
    all_positions_mask[res_index] = mask
  _check_residue_distances(
      all_positions, all_positions_mask, max_ca_ca_distance)
  return all_positions, all_positions_mask
def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
  """Parses atom positions in the target structure and aligns with the query.

  Atoms for each residue in the template structure are indexed to coincide
  with their corresponding residue in the query sequence, according to the
  alignment mapping provided.

  Args:
    mmcif_object: mmcif_parsing.MmcifObject representing the template.
    pdb_id: PDB code for the template.
    mapping: Dictionary mapping indices in the query sequence to indices in
      the template sequence.
    template_sequence: String describing the amino acid sequence for the
      template protein.
    query_sequence: String describing the amino acid sequence for the query
      protein.
    template_chain_id: String ID describing which chain in the structure proto
      should be used.
    kalign_binary_path: The path to a kalign executable used for template
      realignment.

  Returns:
    A tuple with:
    * A dictionary containing the extra features derived from the template
      protein structure.
    * A warning message if the hit was realigned to the actual mmCIF sequence.
      Otherwise None.

  Raises:
    NoChainsError: If the mmcif object doesn't contain any chains.
    SequenceNotInTemplateError: If the given chain id / sequence can't
      be found in the mmcif object.
    QueryToTemplateAlignError: If the actual template in the mmCIF file
      can't be aligned to the query.
    NoAtomDataInTemplateError: If the mmcif object doesn't contain
      atom positions.
    TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
      unmasked residues.
  """
  if mmcif_object is None or not mmcif_object.chain_to_seqres:
    raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id))

  warning = None
  try:
    seqres, chain_id, mapping_offset = _find_template_in_pdb(
        template_chain_id=template_chain_id,
        template_sequence=template_sequence,
        mmcif_object=mmcif_object)
  except SequenceNotInTemplateError:
    # If PDB70 contains a different version of the template, we use the sequence
    # from the mmcif_object.
    chain_id = template_chain_id
    warning = (
        f'The exact sequence {template_sequence} was not found in '
        f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
    logging.warning(warning)
    # This throws an exception if it fails to realign the hit.
    seqres, mapping = _realign_pdb_template_to_query(
        old_template_sequence=template_sequence,
        template_chain_id=template_chain_id,
        mmcif_object=mmcif_object,
        old_mapping=mapping,
        kalign_binary_path=kalign_binary_path)
    logging.info('Sequence in %s_%s: %s successfully realigned to %s',
                 pdb_id, chain_id, template_sequence, seqres)
    # The template sequence changed.
    template_sequence = seqres
    # No mapping offset, the query is aligned to the actual sequence.
    mapping_offset = 0

  try:
    # Essentially set to infinity - we don't want to reject templates unless
    # they're really really bad.
    all_atom_positions, all_atom_mask = _get_atom_positions(
        mmcif_object, chain_id, max_ca_ca_distance=150.0)
  except (CaDistanceError, KeyError) as ex:
    raise NoAtomDataInTemplateError(
        'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
        ) from ex

  # Split into per-residue arrays so positions can be scattered by index below.
  all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0])
  all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

  output_templates_sequence = []
  templates_all_atom_positions = []
  templates_all_atom_masks = []

  for _ in query_sequence:
    # Residues in the query_sequence that are not in the template_sequence:
    templates_all_atom_positions.append(
        np.zeros((residue_constants.atom_type_num, 3)))
    templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num))
    output_templates_sequence.append('-')

  # Scatter template residues into the slots of their aligned query positions.
  for k, v in mapping.items():
    template_index = v + mapping_offset
    templates_all_atom_positions[k] = all_atom_positions[template_index][0]
    templates_all_atom_masks[k] = all_atom_masks[template_index][0]
    output_templates_sequence[k] = template_sequence[v]

  # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
  if np.sum(templates_all_atom_masks) < 5:
    raise TemplateAtomMaskAllZerosError(
        'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' %
        (pdb_id, chain_id, min(mapping.values()) + mapping_offset,
         max(mapping.values()) + mapping_offset))

  output_templates_sequence = ''.join(output_templates_sequence)

  templates_aatype = residue_constants.sequence_to_onehot(
      output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)

  return (
      {
          'template_all_atom_positions': np.array(templates_all_atom_positions),
          'template_all_atom_masks': np.array(templates_all_atom_masks),
          'template_sequence': output_templates_sequence.encode(),
          'template_aatype': np.array(templates_aatype),
          'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
      },
      warning)
def
_build_query_to_hit_index_mapping
(
hit_query_sequence
:
str
,
hit_sequence
:
str
,
indices_hit
:
Sequence
[
int
],
indices_query
:
Sequence
[
int
],
original_query_sequence
:
str
)
->
Mapping
[
int
,
int
]:
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
characters. hit_query_sequence contains only the part of the original query
sequence that matched the hit. When interpreting the indices from the .hhr, we
need to correct for this to recover a mapping from original query sequence to
the hit sequence.
Args:
hit_query_sequence: The portion of the query sequence that is in the .hhr
hit
hit_sequence: The portion of the hit sequence that is in the .hhr
indices_hit: The indices for each aminoacid relative to the hit sequence
indices_query: The indices for each aminoacid relative to the original query
sequence
original_query_sequence: String describing the original query sequence.
Returns:
Dictionary with indices in the original query sequence as keys and indices
in the hit sequence as values.
"""
# If the hit is empty (no aligned residues), return empty mapping
if
not
hit_query_sequence
:
return
{}
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence
=
hit_query_sequence
.
replace
(
'-'
,
''
)
hit_sequence
=
hit_sequence
.
replace
(
'-'
,
''
)
hhsearch_query_offset
=
original_query_sequence
.
find
(
hhsearch_query_sequence
)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx
=
min
(
x
for
x
in
indices_hit
if
x
>
-
1
)
fixed_indices_hit
=
[
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_hit
]
min_idx
=
min
(
x
for
x
in
indices_query
if
x
>
-
1
)
fixed_indices_query
=
[
x
-
min_idx
if
x
>
-
1
else
-
1
for
x
in
indices_query
]
# Zip the corrected indices, ignore case where both seqs have gap characters.
mapping
=
{}
for
q_i
,
q_t
in
zip
(
fixed_indices_query
,
fixed_indices_hit
):
if
q_t
!=
-
1
and
q_i
!=
-
1
:
if
(
q_t
>=
len
(
hit_sequence
)
or
q_i
+
hhsearch_query_offset
>=
len
(
original_query_sequence
)):
continue
mapping
[
q_i
+
hhsearch_query_offset
]
=
q_t
return
mapping
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
  """Outcome of featurizing one template hit.

  `features` is None whenever the hit was skipped or failed; in that case
  `error` and/or `warning` carry the human-readable reason (either may also
  be None, e.g. for silently skipped hits).
  """
  # Template features extracted from the hit, or None if skipped / failed.
  features: Optional[Mapping[str, Any]]
  # Fatal problem description (collected into TemplateSearchResult.errors).
  error: Optional[str]
  # Non-fatal problem description (collected into .warnings).
  warning: Optional[str]
@
functools
.
lru_cache
(
16
,
typed
=
False
)
def
_read_file
(
path
):
with
open
(
path
,
'r'
)
as
f
:
file_data
=
f
.
read
()
return
file_data
def _process_single_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, Optional[str]],
    kalign_binary_path: str,
    strict_error_check: bool = False) -> SingleHitResult:
  """Tries to extract template features from a single HHSearch hit.

  Args:
    query_sequence: Amino-acid sequence of the query.
    hit: The template hit to featurize.
    mmcif_dir: Directory holding `<pdb_id>.cif` files for all templates.
    max_template_date: Hits released after this date are rejected.
    release_dates: Precomputed PDB ID -> release date mapping.
    obsolete_pdbs: Obsolete PDB ID -> replacement ID (or None) mapping.
    kalign_binary_path: Path to a kalign executable, used for realignment.
    strict_error_check: If True, date/duplicate prefilter failures and feature
      extraction failures are reported as errors instead of being skipped.

  Returns:
    A SingleHitResult with either extracted features, or an error/warning
    message explaining why the hit produced no features.
  """
  # Fail hard if we can't get the PDB ID and chain name from the hit.
  hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

  # This hit has been removed (obsoleted) from PDB, skip it.
  if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None:
    return SingleHitResult(
        features=None, error=None, warning=f'Hit {hit_pdb_code} is obsolete.')

  if hit_pdb_code not in release_dates:
    if hit_pdb_code in obsolete_pdbs:
      # Redirect to the replacement entry for the obsolete PDB ID.
      hit_pdb_code = obsolete_pdbs[hit_pdb_code]

  # Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
  try:
    _assess_hhsearch_hit(
        hit=hit,
        hit_pdb_code=hit_pdb_code,
        query_sequence=query_sequence,
        release_dates=release_dates,
        release_date_cutoff=max_template_date)
  except PrefilterError as e:
    msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
    logging.info(msg)
    if strict_error_check and isinstance(e, (DateError, DuplicateError)):
      # In strict mode we treat some prefilter cases as errors.
      return SingleHitResult(features=None, error=msg, warning=None)

    return SingleHitResult(features=None, error=None, warning=None)

  mapping = _build_query_to_hit_index_mapping(
      hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
      query_sequence)

  # The mapping is from the query to the actual hit sequence, so we need to
  # remove gaps (which regardless have a missing confidence score).
  template_sequence = hit.hit_sequence.replace('-', '')

  cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path,
                query_sequence, template_sequence)
  # Fail if we can't find the mmCIF file.
  cif_string = _read_file(cif_path)

  parsing_result = mmcif_parsing.parse(
      file_id=hit_pdb_code, mmcif_string=cif_string)

  if parsing_result.mmcif_object is not None:
    # Double-check the release date against the actual mmCIF header, since the
    # precomputed release_dates mapping may not cover this entry.
    hit_release_date = datetime.datetime.strptime(
        parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
    if hit_release_date > max_template_date:
      error = ('Template %s date (%s) > max template date (%s).' %
               (hit_pdb_code, hit_release_date, max_template_date))
      if strict_error_check:
        return SingleHitResult(features=None, error=error, warning=None)
      else:
        logging.debug(error)
        return SingleHitResult(features=None, error=None, warning=None)

  try:
    features, realign_warning = _extract_template_features(
        mmcif_object=parsing_result.mmcif_object,
        pdb_id=hit_pdb_code,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=hit_chain_id,
        kalign_binary_path=kalign_binary_path)
    if hit.sum_probs is None:
      features['template_sum_probs'] = [0]
    else:
      features['template_sum_probs'] = [hit.sum_probs]

    # It is possible there were some errors when parsing the other chains in the
    # mmCIF file, but the template features for the chain we want were still
    # computed. In such case the mmCIF parsing errors are not relevant.
    return SingleHitResult(
        features=features, error=None, warning=realign_warning)
  except (NoChainsError, NoAtomDataInTemplateError,
          TemplateAtomMaskAllZerosError) as e:
    # These 3 errors indicate missing mmCIF experimental data rather than a
    # problem with the template search, so turn them into warnings.
    warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
               '%s, mmCIF parsing errors: %s'
               % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                  str(e), parsing_result.errors))
    if strict_error_check:
      return SingleHitResult(features=None, error=warning, warning=None)
    else:
      return SingleHitResult(features=None, error=None, warning=warning)
  except Error as e:
    # NOTE(review): this '%.2f' would raise TypeError if hit.sum_probs is None
    # (the branch above allows None) — confirm hmmsearch hits can reach here.
    error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
             '%s, mmCIF parsing errors: %s'
             % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
                str(e), parsing_result.errors))
    return SingleHitResult(features=None, error=error, warning=None)
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
  """Aggregated result of featurizing all template hits for one query."""
  # Stacked template features, keyed by the TEMPLATE_FEATURES names.
  features: Mapping[str, Any]
  # Error messages collected across all processed hits.
  errors: Sequence[str]
  # Warning messages collected across all processed hits.
  warnings: Sequence[str]
class TemplateHitFeaturizer(abc.ABC):
  """An abstract base class for turning template hits to template features."""

  def __init__(
      self,
      mmcif_dir: str,
      max_template_date: str,
      max_hits: int,
      kalign_binary_path: str,
      release_dates_path: Optional[str],
      obsolete_pdbs_path: Optional[str],
      strict_error_check: bool = False):
    """Initializes the Template Search.

    Args:
      mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
        is found by HHSearch, this directory is used to retrieve the template
        data.
      max_template_date: The maximum date permitted for template structures. No
        template with date higher than this date will be returned. In ISO8601
        date format, YYYY-MM-DD.
      max_hits: The maximum number of templates that will be returned.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.
      release_dates_path: An optional path to a file with a mapping from PDB IDs
        to their release dates. Thanks to this we don't have to redundantly
        parse mmCIF files to get that information.
      obsolete_pdbs_path: An optional path to a file containing a mapping from
        obsolete PDB IDs to the PDB IDs of their replacements.
      strict_error_check: If True, then the following will be treated as errors:
        * If any template date is after the max_template_date.
        * If any template has identical PDB ID to the query.
        * If any template is a duplicate of the query.
        * Any feature computation errors.

    Raises:
      ValueError: If mmcif_dir contains no .cif files, or if
        max_template_date is not a valid YYYY-MM-DD date string.
    """
    self._mmcif_dir = mmcif_dir
    if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
      logging.error('Could not find CIFs in %s', self._mmcif_dir)
      raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')

    try:
      self._max_template_date = datetime.datetime.strptime(
          max_template_date, '%Y-%m-%d')
    except ValueError as e:
      # Chain the original parse error explicitly so the traceback reports
      # the cause instead of a misleading "during handling of the above
      # exception, another exception occurred".
      raise ValueError(
          'max_template_date must be set and have format YYYY-MM-DD.') from e
    self._max_hits = max_hits
    self._kalign_binary_path = kalign_binary_path
    self._strict_error_check = strict_error_check

    if release_dates_path:
      logging.info('Using precomputed release dates %s.', release_dates_path)
      self._release_dates = _parse_release_dates(release_dates_path)
    else:
      self._release_dates = {}

    if obsolete_pdbs_path:
      logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path)
      self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
    else:
      self._obsolete_pdbs = {}

  @abc.abstractmethod
  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence."""
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hhsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above).

    Processes hits in decreasing sum_probs order until self._max_hits hits
    have yielded features; remaining hits are ignored.
    """
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
      template_features[template_feature_name] = []

    num_hits = 0
    errors = []
    warnings = []

    # NOTE(review): this sort assumes every hhsearch hit has a non-None
    # sum_probs (the hmmsearch featurizer below guards against None) — confirm.
    for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
      # We got all the templates we wanted, stop processing hits.
      if num_hits >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.info('Skipped invalid hit %s, error: %s, warning: %s',
                     hit.name, result.error, result.warning)
      else:
        # Increment the hit counter, since we got features out of this hit.
        num_hits += 1
        for k in template_features:
          template_features[k].append(result.features[k])

    for name in template_features:
      if num_hits > 0:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
      else:
        # Make sure the feature has correct dtype even if empty.
        template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])

    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
  """A class for turning a3m hits from hmmsearch to template features."""

  def get_templates(
      self,
      query_sequence: str,
      hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
    """Computes the templates for given query sequence (more details above).

    Unlike the hhsearch featurizer, this deduplicates hits by template
    sequence, tolerates hits with no sum_probs, and falls back to a single
    all-zero template when no hit yields features.
    """
    logging.info('Searching for template for: %s', query_sequence)

    template_features = {}
    for template_feature_name in TEMPLATE_FEATURES:
      template_features[template_feature_name] = []

    already_seen = set()
    errors = []
    warnings = []

    # hmmsearch hits may lack sum_probs entirely; only sort when present.
    if not hits or hits[0].sum_probs is None:
      sorted_hits = hits
    else:
      sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)

    for hit in sorted_hits:
      # We got all the templates we wanted, stop processing hits.
      if len(already_seen) >= self._max_hits:
        break

      result = _process_single_hit(
          query_sequence=query_sequence,
          hit=hit,
          mmcif_dir=self._mmcif_dir,
          max_template_date=self._max_template_date,
          release_dates=self._release_dates,
          obsolete_pdbs=self._obsolete_pdbs,
          strict_error_check=self._strict_error_check,
          kalign_binary_path=self._kalign_binary_path)

      if result.error:
        errors.append(result.error)

      # There could be an error even if there are some results, e.g. thrown by
      # other unparsable chains in the same mmCIF file.
      if result.warning:
        warnings.append(result.warning)

      if result.features is None:
        logging.debug('Skipped invalid hit %s, error: %s, warning: %s',
                      hit.name, result.error, result.warning)
      else:
        # Skip templates whose exact sequence was already featurized.
        already_seen_key = result.features['template_sequence']
        if already_seen_key in already_seen:
          continue
        # Increment the hit counter, since we got features out of this hit.
        already_seen.add(already_seen_key)
        for k in template_features:
          template_features[k].append(result.features[k])

    if already_seen:
      for name in template_features:
        template_features[name] = np.stack(
            template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
    else:
      num_res = len(query_sequence)
      # Construct a default template with all zeros.
      # NOTE: the `np.object` alias used here previously was deprecated in
      # NumPy 1.20 and removed in 1.24; the builtin `object` is the supported
      # and always-equivalent spelling.
      template_features = {
          'template_aatype': np.zeros(
              (1, num_res, len(residue_constants.restypes_with_x_and_gap)),
              np.float32),
          'template_all_atom_masks': np.zeros(
              (1, num_res, residue_constants.atom_type_num), np.float32),
          'template_all_atom_positions': np.zeros(
              (1, num_res, residue_constants.atom_type_num, 3), np.float32),
          'template_domain_names': np.array([''.encode()], dtype=object),
          'template_sequence': np.array([''.encode()], dtype=object),
          'template_sum_probs': np.array([0], dtype=np.float32)
      }
    return TemplateSearchResult(
        features=template_features, errors=errors, warnings=warnings)
alphafold/data/tools/__init__.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers for third party tools."""
alphafold/data/tools/hhblits.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""
import
glob
import
os
import
subprocess
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
from
absl
import
logging
from
alphafold.data.tools
import
utils
# Internal import (7716).
_HHBLITS_DEFAULT_P
=
20
_HHBLITS_DEFAULT_Z
=
500
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 databases: Sequence[str],
                 n_cpu: int = 24,
                 n_iter: int = 3,
                 e_value: float = 0.001,
                 maxseq: int = 1_000_000,
                 realign_max: int = 100_000,
                 maxfilt: int = 100_000,
                 min_prefilter_hits: int = 1000,
                 all_seqs: bool = False,
                 alt: Optional[int] = None,
                 p: int = _HHBLITS_DEFAULT_P,
                 z: int = _HHBLITS_DEFAULT_Z):
        """Initializes the Python HHblits wrapper.

        Args:
          binary_path: The path to the HHblits executable.
          databases: A sequence of HHblits database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to give HHblits.
          n_iter: The number of HHblits iterations.
          e_value: The E-value, see HHblits docs for more details.
          maxseq: The maximum number of rows in an input alignment. Note that this
            parameter is only supported in HHBlits version 3.1 and higher.
          realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
          maxfilt: Max number of hits allowed to pass the 2nd prefilter.
            HHblits default: 20000.
          min_prefilter_hits: Min number of hits to pass prefilter.
            HHblits default: 100.
          all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
            HHblits default: False.
          alt: Show up to this many alternative alignments.
          p: Minimum Prob for a hit to be included in the output hhr file.
            HHblits default: 20.
          z: Hard cap on number of hits reported in the hhr file.
            HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

        Raises:
          ValueError: If no files matching one of the database prefixes can be
            found on disk.
        """
        self.binary_path = binary_path
        self.databases = databases

        # Fail fast: each database path is a prefix that must match at least one
        # file (e.g. <prefix>_hhm.ffindex) on disk.
        for database_path in self.databases:
            if not glob.glob(database_path + '_*'):
                logging.error('Could not find HHBlits database %s', database_path)
                raise ValueError(
                    f'Could not find HHBlits database {database_path}')
        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
        """Queries the database using HHblits.

        Args:
          input_fasta_path: Path to the query sequence file in FASTA format.

        Returns:
          A single-element list holding a dict with the resulting a3m alignment
          string, the raw process stdout/stderr bytes, and the n_iter / e_value
          settings that were used.

        Raises:
          RuntimeError: If HHblits exits with a non-zero return code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

            # Each database is passed to HHblits via its own -d flag.
            db_cmd = []
            for db_path in self.databases:
                db_cmd.append('-d')
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                '-i', input_fasta_path,
                '-cpu', str(self.n_cpu),
                '-oa3m', a3m_path,
                '-o', '/dev/null',
                '-n', str(self.n_iter),
                '-e', str(self.e_value),
                '-maxseq', str(self.maxseq),
                '-realign_max', str(self.realign_max),
                '-maxfilt', str(self.maxfilt),
                '-min_prefilter_hits', str(self.min_prefilter_hits)]
            if self.all_seqs:
                cmd += ['-all']
            if self.alt:
                cmd += ['-alt', str(self.alt)]
            # -p / -Z are only passed when they differ from the wrapper defaults.
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ['-p', str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ['-Z', str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('HHblits query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error('HHblits failed. HHblits stderr begin:')
                for error_line in stderr.decode('utf-8').splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error('HHblits stderr end')
                # stderr is truncated to 500k bytes in the raised message.
                raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
                    stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))

            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value)
            return [raw_output]
alphafold/data/tools/hhsearch.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""
import
glob
import
os
import
subprocess
from
typing
import
Sequence
from
absl
import
logging
from
alphafold.data
import
parsers
from
alphafold.data.tools
import
utils
# Internal import (7716).
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 databases: Sequence[str],
                 maxseq: int = 1_000_000):
        """Initializes the Python HHsearch wrapper.

        Args:
          binary_path: The path to the HHsearch executable.
          databases: A sequence of HHsearch database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          maxseq: The maximum number of rows in an input alignment. Note that
            this parameter is only supported in HHBlits version 3.1 and higher.

        Raises:
          ValueError: If no files matching a database prefix can be found.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.maxseq = maxseq

        # Each database path is a prefix; at least one matching file must exist.
        for database_path in self.databases:
            if not glob.glob(database_path + '_*'):
                logging.error('Could not find HHsearch database %s', database_path)
                raise ValueError(
                    f'Could not find HHsearch database {database_path}')

    @property
    def output_format(self) -> str:
        # HHsearch reports its hits in the hhr text format.
        return 'hhr'

    @property
    def input_format(self) -> str:
        # HHsearch consumes alignments in the a3m format.
        return 'a3m'

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m.

        Args:
          a3m: The query alignment as an a3m-formatted string.

        Returns:
          The raw hhr-formatted output of HHsearch.

        Raises:
          RuntimeError: If HHsearch exits with a non-zero return code.
        """
        with utils.tmpdir_manager() as tmp_dir:
            query_path = os.path.join(tmp_dir, 'query.a3m')
            result_path = os.path.join(tmp_dir, 'output.hhr')
            with open(query_path, 'w') as query_file:
                query_file.write(a3m)

            database_args = []
            for database_path in self.databases:
                database_args.extend(['-d', database_path])
            cmd = [
                self.binary_path,
                '-i', query_path,
                '-o', result_path,
                '-maxseq', str(self.maxseq),
            ] + database_args

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing('HHsearch query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'),
                        stderr[:100_000].decode('utf-8')))

            with open(result_path) as result_file:
                hhr = result_file.read()
        return hhr

    def get_template_hits(self,
                          output_string: str,
                          input_sequence: str) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        del input_sequence  # Used by hmmsearch but not needed for hhsearch.
        return parsers.parse_hhr(output_string)
alphafold/data/tools/hmmbuild.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
import
os
import
re
import
subprocess
from
absl
import
logging
from
alphafold.data.tools
import
utils
# Internal import (7716).
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
          binary_path: The path to the hmmbuild executable.
          singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
            just use a common substitution score matrix.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds a HHM for the aligned sequences given as a Stockholm string.

        Args:
          sto: A string with the aligned sequences in the Stockholm format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
          a3m: A string with the aligned sequences in the A3M format.

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                # Lowercase letters denote insertions in a3m; strip them so the
                # remaining columns form a plain aligned MSA.
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds a HMM for the aligned sequences given as an MSA string.

        Args:
          msa: A string with the aligned sequences, in A3M or STO format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
          ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            # Bug fix: the original implicit string concatenation was missing a
            # space, producing "...onlyhand and fast supported.".
            raise ValueError(f'Invalid model_construction {model_construction} - '
                             'only hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError(
                    'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

            return hmm
alphafold/data/tools/hmmsearch.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""
import
os
import
subprocess
from
typing
import
Optional
,
Sequence
from
absl
import
logging
from
alphafold.data
import
parsers
from
alphafold.data.tools
import
hmmbuild
from
alphafold.data.tools
import
utils
# Internal import (7716).
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 hmmbuild_binary_path: str,
                 database_path: str,
                 flags: Optional[Sequence[str]] = None):
        """Initializes the Python hmmsearch wrapper.

        Args:
          binary_path: The path to the hmmsearch executable.
          hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
            an hmm from an input a3m.
          database_path: The path to the hmmsearch database (FASTA format).
          flags: List of flags to be used by hmmsearch.

        Raises:
          ValueError: If the database file does not exist.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            flags = [
                '--F1', '0.1',
                '--F2', '0.1',
                '--F3', '0.1',
                '--incE', '100',
                '-E', '100',
                '--domE', '100',
                '--incdomE', '100',
            ]
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        # Hits are returned as a Stockholm alignment.
        return 'sto'

    @property
    def input_format(self) -> str:
        # Queries are expected as Stockholm MSAs.
        return 'sto'

    def query(self, msa_sto: str) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        profile = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto, model_construction='hand')
        return self.query_with_hmm(profile)

    def query_with_hmm(self, hmm: str) -> str:
        """Queries the database using hmmsearch using a given hmm."""
        with utils.tmpdir_manager() as tmp_dir:
            profile_path = os.path.join(tmp_dir, 'query.hmm')
            result_path = os.path.join(tmp_dir, 'output.sto')
            with open(profile_path, 'w') as profile_file:
                profile_file.write(hmm)

            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8',
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', result_path,
                profile_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(result_path) as result_file:
                out_msa = result_file.read()
        return out_msa

    def get_template_hits(self,
                          output_string: str,
                          input_sequence: str) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        a3m_string = parsers.convert_stockholm_to_a3m(
            output_string, remove_first_row_gaps=False)
        template_hits = parsers.parse_hmmsearch_a3m(
            query_sequence=input_sequence,
            a3m_string=a3m_string,
            skip_first=False)
        return template_hits
alphafold/data/tools/jackhmmer.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run Jackhmmer from Python."""
from
concurrent
import
futures
import
glob
import
os
import
subprocess
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Sequence
from
urllib
import
request
from
absl
import
logging
from
alphafold.data
import
parsers
from
alphafold.data.tools
import
utils
# Internal import (7716).
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 database_path: str,
                 n_cpu: int = 24,
                 n_iter: int = 1,
                 e_value: float = 0.0001,
                 z_value: Optional[int] = None,
                 get_tblout: bool = False,
                 filter_f1: float = 0.0005,
                 filter_f2: float = 0.00005,
                 filter_f3: float = 0.0000005,
                 incdom_e: Optional[float] = None,
                 dom_e: Optional[float] = None,
                 num_streamed_chunks: Optional[int] = None,
                 streaming_callback: Optional[Callable[[int], None]] = None):
        """Initializes the Python Jackhmmer wrapper.

        Args:
          binary_path: The path to the jackhmmer executable.
          database_path: The path to the jackhmmer database (FASTA format).
          n_cpu: The number of CPUs to give Jackhmmer.
          n_iter: The number of Jackhmmer iterations.
          e_value: The E-value, see Jackhmmer docs for more details.
          z_value: The Z-value, see Jackhmmer docs for more details.
          get_tblout: Whether to save tblout string.
          filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
          filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
          filter_f3: Forward pre-filter, set to >1.0 to turn off.
          incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
            round.
          dom_e: Domain e-value criteria for inclusion in tblout.
          num_streamed_chunks: Number of database chunks to stream over.
          streaming_callback: Callback function run after each chunk iteration with
            the iteration number as argument.

        Raises:
          ValueError: If the database does not exist and streaming is not used.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # When streaming, chunks are downloaded on demand, so the local database
        # file is not required to exist up front.
        if not os.path.exists(self.database_path) and num_streamed_chunks is None:
            logging.error('Could not find Jackhmmer database %s', database_path)
            raise ValueError(f'Could not find Jackhmmer database {database_path}')

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(self,
                     input_fasta_path: str,
                     database_path: str,
                     max_sequences: Optional[int] = None) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer.

        Args:
          input_fasta_path: Path to the query FASTA file.
          database_path: Path to the database (chunk) to search against.
          max_sequences: If set, truncate the resulting Stockholm MSA to at most
            this many sequences.

        Returns:
          A dict with the Stockholm alignment, the tblout string (if requested),
          raw stderr bytes, and the n_iter / e_value settings used.

        Raises:
          RuntimeError: If Jackhmmer exits with a non-zero return code.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, 'output.sto')

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expensive of sensitivity.  They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                '-o', '/dev/null',
                '-A', sto_path,
                '--noali',
                '--F1', str(self.filter_f1),
                '--F2', str(self.filter_f2),
                '--F3', str(self.filter_f3),
                '--incE', str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                '-E', str(self.e_value),
                '--cpu', str(self.n_cpu),
                '-N', str(self.n_iter)
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
                cmd_flags.extend(['--tblout', tblout_path])

            if self.z_value:
                cmd_flags.extend(['-Z', str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(['--domE', str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(['--incdomE', str(self.incdom_e)])

            cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
                                                    database_path]

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'Jackhmmer ({os.path.basename(database_path)}) query'):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))

            # Get e-values for each target name
            tbl = ''
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            if max_sequences is None:
                with open(sto_path) as f:
                    sto = f.read()
            else:
                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value)

        return raw_output

    def query(self,
              input_fasta_path: str,
              max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
        """Queries the database using Jackhmmer.

        Args:
          input_fasta_path: Path to the query FASTA file.
          max_sequences: If set, truncate each chunk's MSA to this many sequences.

        Returns:
          A list with one raw-output dict per database chunk (a single-element
          list when streaming is not used).
        """
        if self.num_streamed_chunks is None:
            single_chunk_result = self._query_chunk(
                input_fasta_path, self.database_path, max_sequences)
            return [single_chunk_result]

        db_basename = os.path.basename(self.database_path)

        # Fix: named helper functions instead of lambda assignments (PEP 8 E731).
        def db_remote_chunk(db_idx):
            return f'{self.database_path}.{db_idx}'

        def db_local_chunk(db_idx):
            return f'/tmp/ramdisk/{db_basename}.{db_idx}'

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk('[0-9]*')):
            try:
                os.remove(f)
            except OSError:
                # Fix: log through the module logger instead of print() so the
                # message ends up in the pipeline logs. Best-effort cleanup only.
                logging.error('OSError while deleting %s', f)

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_output = []
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1))

                # Run Jackhmmer with the chunk
                future.result()
                chunked_output.append(self._query_chunk(
                    input_fasta_path, db_local_chunk(i), max_sequences))

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works even for
                # databases with only 1 chunk.
                if i < self.num_streamed_chunks:
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_output
alphafold/data/tools/kalign.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import
os
import
subprocess
from
typing
import
Sequence
from
absl
import
logging
from
alphafold.data.tools
import
utils
# Internal import (7716).
def
_to_a3m
(
sequences
:
Sequence
[
str
])
->
str
:
"""Converts sequences to an a3m file."""
names
=
[
'sequence %d'
%
i
for
i
in
range
(
1
,
len
(
sequences
)
+
1
)]
a3m
=
[]
for
sequence
,
name
in
zip
(
sequences
,
names
):
a3m
.
append
(
u
'>'
+
name
+
u
'
\n
'
)
a3m
.
append
(
sequence
+
u
'
\n
'
)
return
''
.
join
(
a3m
)
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
          binary_path: The path to the Kalign binary.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
          sequences: A list of query sequence strings. The sequences have to be at
            least 6 residues long (Kalign requires this). Note that the order in
            which you give the sequences might alter the output slightly as
            different alignment tree might get constructed.

        Returns:
          A string with the alignment in a3m format.

        Raises:
          RuntimeError: If Kalign fails.
          ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info('Aligning %d sequences', len(sequences))

        # Validate all inputs up front, before any temp files are created.
        for seq in sequences:
            if len(seq) < 6:
                raise ValueError('Kalign requires all sequences to be at least 6 '
                                 'residues long. Got %s (%d residues).'
                                 % (seq, len(seq)))

        with utils.tmpdir_manager() as tmp_dir:
            fasta_path = os.path.join(tmp_dir, 'input.fasta')
            result_path = os.path.join(tmp_dir, 'output.a3m')

            with open(fasta_path, 'w') as fasta_file:
                fasta_file.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                '-i', fasta_path,
                '-o', result_path,
                '-format', 'fasta',
            ]

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with utils.timing('Kalign query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError(
                    'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(result_path) as result_file:
                a3m = result_file.read()

            return a3m
alphafold/data/tools/utils.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import
contextlib
import
shutil
import
tempfile
import
time
from
typing
import
Optional
from
absl
import
logging
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit.

    Args:
      base_dir: Optional parent directory in which to create the temp dir.

    Yields:
      The path of a freshly created temporary directory; it is removed
      (best-effort) when the context exits, even on error.
    """
    temp_path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield temp_path
    finally:
        # ignore_errors keeps cleanup best-effort (e.g. files still in use).
        shutil.rmtree(temp_path, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
  """Context manager that logs the wall-clock duration of the enclosed block."""
  logging.info('Started %s', msg)
  start = time.time()
  yield
  elapsed = time.time() - start
  logging.info('Finished %s in %.3f seconds', msg, elapsed)
alphafold/model/__init__.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alphafold model."""
alphafold/model/all_atom.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations.
Generally we employ two different representations for all atom coordinates,
one is atom37 where each heavy atom corresponds to a given position in a 37
dimensional array, This mapping is non amino acid specific, but each slot
corresponds to an atom of a given name, for example slot 12 always corresponds
to 'C delta 1', positions that are not present for a given amino acid are
zeroed out and denoted by a mask.
The other representation we employ is called atom14, this is a more dense way
of representing atoms with 14 slots. Here a given slot will correspond to a
different kind of atom depending on amino acid type, for example slot 5
corresponds to 'N delta 2' for Aspargine, but to 'C delta 1' for Isoleucine.
14 is chosen because it is the maximum number of heavy atoms for any standard
amino acid.
The order of slots can be found in 'residue_constants.residue_atoms'.
Internally the model uses the atom14 representation because it is
computationally more efficient.
The internal atom14 representation is turned into the atom37 at the output of
the network to facilitate easier conversion to existing protein datastructures.
"""
from
typing
import
Dict
,
Optional
from
alphafold.common
import
residue_constants
from
alphafold.model
import
r3
from
alphafold.model
import
utils
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
def squared_difference(x, y):
  """Returns the elementwise squared difference (x - y)**2."""
  delta = x - y
  return jnp.square(delta)
def get_chi_atom_indices():
  """Returns atom indices needed to compute chi angles for all residue types.

  Returns:
    A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
    in the order specified in residue_constants.restypes + unknown residue type
    at the end. For chi angles which are not defined on the residue, the
    positions indices are by default set to 0.
  """
  all_indices = []
  for one_letter in residue_constants.restypes:
    three_letter = residue_constants.restype_1to3[one_letter]
    # One row of four atom37 indices per defined chi angle of this residue.
    per_chi = [
        [residue_constants.atom_order[atom] for atom in chi_atoms]
        for chi_atoms in residue_constants.chi_angles_atoms[three_letter]
    ]
    # Pad to exactly 4 chi angles; undefined angles point every slot at 0.
    per_chi.extend([[0, 0, 0, 0]] * (4 - len(per_chi)))
    all_indices.append(per_chi)
  # The UNKNOWN residue has no chi angles at all.
  all_indices.append([[0, 0, 0, 0]] * 4)
  return jnp.asarray(all_indices)
def atom14_to_atom37(atom14_data: jnp.ndarray,  # (N, 14, ...)
                     batch: Dict[str, jnp.ndarray]
                    ) -> jnp.ndarray:  # (N, 37, ...)
  """Convert atom14 to atom37 representation.

  Args:
    atom14_data: Per-atom data in the dense atom14 layout, shape (N, 14) or
      (N, 14, ...).
    batch: Feature dict; must contain 'residx_atom37_to_atom14' (per-residue
      gather indices into the atom14 axis) and 'atom37_atom_exists'
      (per-residue atom37 existence mask).

  Returns:
    The same data laid out as atom37, shape (N, 37) or (N, 37, ...), with
    slots for atoms that do not exist zeroed out.
  """
  assert len(atom14_data.shape) in [2, 3]
  assert 'residx_atom37_to_atom14' in batch
  assert 'atom37_atom_exists' in batch

  atom37_data = utils.batched_gather(atom14_data,
                                     batch['residx_atom37_to_atom14'],
                                     batch_dims=1)
  if len(atom14_data.shape) == 2:
    # Cast the mask to the data dtype — consistent with the 3-D branch and
    # with atom37_to_atom14, and avoids unintended dtype promotion.
    atom37_data *= batch['atom37_atom_exists'].astype(atom37_data.dtype)
  elif len(atom14_data.shape) == 3:
    atom37_data *= batch['atom37_atom_exists'][:, :,
                                               None].astype(atom37_data.dtype)
  return atom37_data
def atom37_to_atom14(
    atom37_data: jnp.ndarray,  # (N, 37, ...)
    batch: Dict[str, jnp.ndarray]) -> jnp.ndarray:  # (N, 14, ...)
  """Convert atom37 to atom14 representation.

  Gathers per-atom data from the sparse atom37 layout into the dense atom14
  layout using the per-residue index map in the batch, then zeroes out slots
  for atoms that do not exist for the residue type.
  """
  assert len(atom37_data.shape) in [2, 3]
  assert 'residx_atom14_to_atom37' in batch
  assert 'atom14_atom_exists' in batch

  # Gather along the atom axis: for each residue, pick the 14 relevant
  # entries out of the 37 slots.
  atom14_data = utils.batched_gather(atom37_data,
                                     batch['residx_atom14_to_atom37'],
                                     batch_dims=1)
  if len(atom37_data.shape) == 2:
    atom14_data *= batch['atom14_atom_exists'].astype(atom14_data.dtype)
  elif len(atom37_data.shape) == 3:
    # Broadcast the (N, 14) mask over the trailing feature axis.
    atom14_data *= batch['atom14_atom_exists'][:, :, None].astype(
        atom14_data.dtype)
  return atom14_data
def atom37_to_frames(
    aatype: jnp.ndarray,  # (...)
    all_atom_positions: jnp.ndarray,  # (..., 37, 3)
    all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[str, jnp.ndarray]:
  """Computes the frames for the up to 8 rigid groups for each residue.

  The rigid groups are defined by the possible torsions in a given amino acid.
  We group the atoms according to their dependence on the torsion angles into
  "rigid groups".  E.g., the position of atoms in the chi2-group depend on
  chi1 and chi2, but do not depend on chi3 or chi4.
  Jumper et al. (2021) Suppl. Table 2 and corresponding text.

  Args:
    aatype: Amino acid type, given as array with integers.
    all_atom_positions: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
  Returns:
    Dictionary containing:
      * 'rigidgroups_gt_frames': 8 Frames corresponding to 'all_atom_positions'
           represented as flat 12 dimensional array.
      * 'rigidgroups_gt_exists': Mask denoting whether the atom positions for
          the given frame are available in the ground truth, e.g. if they were
          resolved in the experiment.
      * 'rigidgroups_group_exists': Mask denoting whether given group is in
          principle present for given amino acid type.
      * 'rigidgroups_group_is_ambiguous': Mask denoting whether frame is
          affected by naming ambiguity.
      * 'rigidgroups_alt_gt_frames': 8 Frames with alternative atom renaming
          corresponding to 'all_atom_positions' represented as flat
          12 dimensional array.
  """
  # 0: 'backbone group',
  # 1: 'pre-omega-group', (empty)
  # 2: 'phi-group', (currently empty, because it defines only hydrogens)
  # 3: 'psi-group',
  # 4,5,6,7: 'chi1,2,3,4-group'
  aatype_in_shape = aatype.shape

  # If there is a batch axis, just flatten it away, and reshape everything
  # back at the end of the function.
  aatype = jnp.reshape(aatype, [-1])
  all_atom_positions = jnp.reshape(all_atom_positions, [-1, 37, 3])
  all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

  # Create an array with the atom names.
  # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
  restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)

  # 0: backbone frame
  restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']

  # 3: 'psi-group'
  restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

  # 4,5,6,7: 'chi1,2,3,4-group'
  for restype, restype_letter in enumerate(residue_constants.restypes):
    resname = residue_constants.restype_1to3[restype_letter]
    for chi_idx in range(4):
      if residue_constants.chi_angles_mask[restype][chi_idx]:
        atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
        restype_rigidgroup_base_atom_names[
            restype, chi_idx + 4, :] = atom_names[1:]

  # Create mask for existing rigid groups.
  restype_rigidgroup_mask = np.zeros([21, 8], dtype=np.float32)
  restype_rigidgroup_mask[:, 0] = 1
  restype_rigidgroup_mask[:, 3] = 1
  restype_rigidgroup_mask[:20, 4:] = residue_constants.chi_angles_mask

  # Translate atom names into atom37 indices.
  lookuptable = residue_constants.atom_order.copy()
  lookuptable[''] = 0
  restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
      restype_rigidgroup_base_atom_names)

  # Compute the gather indices for all residues in the chain.
  # shape (N, 8, 3)
  residx_rigidgroup_base_atom37_idx = utils.batched_gather(
      restype_rigidgroup_base_atom37_idx, aatype)

  # Gather the base atom positions for each rigid group.
  base_atom_pos = utils.batched_gather(
      all_atom_positions,
      residx_rigidgroup_base_atom37_idx,
      batch_dims=1)

  # Compute the Rigids.
  gt_frames = r3.rigids_from_3_points(
      point_on_neg_x_axis=r3.vecs_from_tensor(base_atom_pos[:, :, 0, :]),
      origin=r3.vecs_from_tensor(base_atom_pos[:, :, 1, :]),
      point_on_xy_plane=r3.vecs_from_tensor(base_atom_pos[:, :, 2, :])
  )

  # Compute a mask whether the group exists.
  # (N, 8)
  group_exists = utils.batched_gather(restype_rigidgroup_mask, aatype)

  # Compute a mask whether ground truth exists for the group
  gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
      all_atom_mask.astype(jnp.float32),
      residx_rigidgroup_base_atom37_idx,
      batch_dims=1)
  gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

  # Adapt backbone frame to old convention (mirror x-axis and z-axis).
  rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
  rots[0, 0, 0] = -1
  rots[0, 2, 2] = -1
  gt_frames = r3.rigids_mul_rots(gt_frames, r3.rots_from_tensor3x3(rots))

  # The frames for ambiguous rigid groups are just rotated by 180 degree around
  # the x-axis. The ambiguous group is always the last chi-group.
  restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
  restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])

  for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
    restype = residue_constants.restype_order[
        residue_constants.restype_3to1[resname]]
    chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
    restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
    restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
    restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

  # Gather the ambiguity information for each residue.
  residx_rigidgroup_is_ambiguous = utils.batched_gather(
      restype_rigidgroup_is_ambiguous, aatype)
  residx_rigidgroup_ambiguity_rot = utils.batched_gather(
      restype_rigidgroup_rots, aatype)

  # Create the alternative ground truth frames.
  alt_gt_frames = r3.rigids_mul_rots(
      gt_frames, r3.rots_from_tensor3x3(residx_rigidgroup_ambiguity_rot))

  gt_frames_flat12 = r3.rigids_to_tensor_flat12(gt_frames)
  alt_gt_frames_flat12 = r3.rigids_to_tensor_flat12(alt_gt_frames)

  # reshape back to original residue layout
  # (the original code reshaped gt_frames_flat12 twice to the same shape;
  # the redundant second reshape was a no-op and has been removed)
  gt_frames_flat12 = jnp.reshape(gt_frames_flat12, aatype_in_shape + (8, 12))
  gt_exists = jnp.reshape(gt_exists, aatype_in_shape + (8,))
  group_exists = jnp.reshape(group_exists, aatype_in_shape + (8,))
  residx_rigidgroup_is_ambiguous = jnp.reshape(residx_rigidgroup_is_ambiguous,
                                               aatype_in_shape + (8,))
  alt_gt_frames_flat12 = jnp.reshape(alt_gt_frames_flat12,
                                     aatype_in_shape + (8, 12,))

  return {
      'rigidgroups_gt_frames': gt_frames_flat12,  # (..., 8, 12)
      'rigidgroups_gt_exists': gt_exists,  # (..., 8)
      'rigidgroups_group_exists': group_exists,  # (..., 8)
      'rigidgroups_group_is_ambiguous':
          residx_rigidgroup_is_ambiguous,  # (..., 8)
      'rigidgroups_alt_gt_frames': alt_gt_frames_flat12,  # (..., 8, 12)
  }
def atom37_to_torsion_angles(
    aatype: jnp.ndarray,  # (B, N)
    all_atom_pos: jnp.ndarray,  # (B, N, 37, 3)
    all_atom_mask: jnp.ndarray,  # (B, N, 37)
    placeholder_for_undefined: bool = False,
) -> Dict[str, jnp.ndarray]:
  """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

  The 7 torsion angles are in the order
  '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
  here pre_omega denotes the omega torsion angle between the given amino acid
  and the previous amino acid.

  Args:
    aatype: Amino acid type, given as array with integers.
    all_atom_pos: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
    placeholder_for_undefined: flag denoting whether to set masked torsion
      angles to zero.
  Returns:
    Dict containing:
      * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
        2 dimensions denote sin and cos respectively
      * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
        with the angle shifted by pi for all chi angles affected by the naming
        ambiguities.
      * 'torsion_angles_mask': Mask for which chi angles are present.
  """

  # Map aatype > 20 to 'Unknown' (20).
  aatype = jnp.minimum(aatype, 20)

  # Compute the backbone angles.
  num_batch, num_res = aatype.shape

  # Shift positions/masks by one residue so that index i holds residue i-1;
  # the first residue gets an all-zero (fully masked) predecessor.
  pad = jnp.zeros([num_batch, 1, 37, 3], jnp.float32)
  prev_all_atom_pos = jnp.concatenate([pad, all_atom_pos[:, :-1, :, :]],
                                      axis=1)

  pad = jnp.zeros([num_batch, 1, 37], jnp.float32)
  prev_all_atom_mask = jnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)

  # For each torsion angle collect the 4 atom positions that define this angle.
  # shape (B, N, atoms=4, xyz=3)
  pre_omega_atom_pos = jnp.concatenate(
      [prev_all_atom_pos[:, :, 1:3, :],  # prev CA, C
       all_atom_pos[:, :, 0:2, :]  # this N, CA
      ], axis=-2)
  phi_atom_pos = jnp.concatenate(
      [prev_all_atom_pos[:, :, 2:3, :],  # prev C
       all_atom_pos[:, :, 0:3, :]  # this N, CA, C
      ], axis=-2)
  psi_atom_pos = jnp.concatenate(
      [all_atom_pos[:, :, 0:3, :],  # this N, CA, C
       all_atom_pos[:, :, 4:5, :]  # this O
      ], axis=-2)

  # Collect the masks from these atoms.
  # Shape [batch, num_res]
  pre_omega_mask = (
      jnp.prod(prev_all_atom_mask[:, :, 1:3], axis=-1)  # prev CA, C
      * jnp.prod(all_atom_mask[:, :, 0:2], axis=-1))  # this N, CA
  phi_mask = (
      prev_all_atom_mask[:, :, 2]  # prev C
      * jnp.prod(all_atom_mask[:, :, 0:3], axis=-1))  # this N, CA, C
  psi_mask = (
      jnp.prod(all_atom_mask[:, :, 0:3], axis=-1) *  # this N, CA, C
      all_atom_mask[:, :, 4])  # this O

  # Collect the atoms for the chi-angles.
  # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
  chi_atom_indices = get_chi_atom_indices()
  # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
  atom_indices = utils.batched_gather(
      params=chi_atom_indices, indices=aatype, axis=0, batch_dims=0)
  # Gather atom positions. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].
  chis_atom_pos = utils.batched_gather(
      params=all_atom_pos, indices=atom_indices, axis=-2,
      batch_dims=2)

  # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
  chi_angles_mask = list(residue_constants.chi_angles_mask)
  chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
  chi_angles_mask = jnp.asarray(chi_angles_mask)

  # Compute the chi angle mask. I.e. which chis angles exist according to the
  # aatype. Shape [batch, num_res, chis=4].
  chis_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype,
                                   axis=0, batch_dims=0)

  # Constrain the chis_mask to those chis, where the ground truth coordinates of
  # all defining four atoms are available.
  # Gather the chi angle atoms mask. Shape: [batch, num_res, chis=4, atoms=4].
  chi_angle_atoms_mask = utils.batched_gather(
      params=all_atom_mask, indices=atom_indices, axis=-1,
      batch_dims=2)
  # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
  chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
  chis_mask = chis_mask * (chi_angle_atoms_mask).astype(jnp.float32)

  # Stack all torsion angle atom positions.
  # Shape (B, N, torsions=7, atoms=4, xyz=3)
  torsions_atom_pos = jnp.concatenate(
      [pre_omega_atom_pos[:, :, None, :, :],
       phi_atom_pos[:, :, None, :, :],
       psi_atom_pos[:, :, None, :, :],
       chis_atom_pos
      ], axis=2)

  # Stack up masks for all torsion angles.
  # shape (B, N, torsions=7)
  torsion_angles_mask = jnp.concatenate(
      [pre_omega_mask[:, :, None],
       phi_mask[:, :, None],
       psi_mask[:, :, None],
       chis_mask
      ], axis=2)

  # Create a frame from the first three atoms:
  # First atom: point on x-y-plane
  # Second atom: point on negative x-axis
  # Third atom: origin
  # r3.Rigids (B, N, torsions=7)
  torsion_frames = r3.rigids_from_3_points(
      point_on_neg_x_axis=r3.vecs_from_tensor(
          torsions_atom_pos[:, :, :, 1, :]),
      origin=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
      point_on_xy_plane=r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))

  # Compute the position of the forth atom in this frame (y and z coordinate
  # define the chi angle)
  # r3.Vecs (B, N, torsions=7)
  forth_atom_rel_pos = r3.rigids_mul_vecs(
      r3.invert_rigids(torsion_frames),
      r3.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))

  # Normalize to have the sin and cos of the torsion angle.
  # jnp.ndarray (B, N, torsions=7, sincos=2)
  torsion_angles_sin_cos = jnp.stack(
      [forth_atom_rel_pos.z, forth_atom_rel_pos.y], axis=-1)
  torsion_angles_sin_cos /= jnp.sqrt(
      jnp.sum(jnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True)
      + 1e-8)

  # Mirror psi, because we computed it from the Oxygen-atom.
  torsion_angles_sin_cos *= jnp.asarray(
      [1., 1., -1., 1., 1., 1., 1.])[None, None, :, None]

  # Create alternative angles for ambiguous atom names.
  chi_is_ambiguous = utils.batched_gather(
      jnp.asarray(residue_constants.chi_pi_periodic), aatype)
  # First 3 torsions (omega/phi/psi) are never ambiguous; ambiguous chis get
  # their sign flipped (angle shifted by pi).
  mirror_torsion_angles = jnp.concatenate(
      [jnp.ones([num_batch, num_res, 3]),
       1.0 - 2.0 * chi_is_ambiguous], axis=-1)
  alt_torsion_angles_sin_cos = (
      torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])

  if placeholder_for_undefined:
    # Add placeholder torsions in place of undefined torsion angles
    # (e.g. N-terminus pre-omega)
    placeholder_torsions = jnp.stack([
        jnp.ones(torsion_angles_sin_cos.shape[:-1]),
        jnp.zeros(torsion_angles_sin_cos.shape[:-1])
    ], axis=-1)
    torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[
        ..., None] + placeholder_torsions * (1 - torsion_angles_mask[...,
                                                                     None])
    alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[
        ..., None] + placeholder_torsions * (1 - torsion_angles_mask[...,
                                                                     None])

  return {
      'torsion_angles_sin_cos': torsion_angles_sin_cos,  # (B, N, 7, 2)
      'alt_torsion_angles_sin_cos': alt_torsion_angles_sin_cos,  # (B, N, 7, 2)
      'torsion_angles_mask': torsion_angles_mask  # (B, N, 7)
  }
def torsion_angles_to_frames(
    aatype: jnp.ndarray,  # (N)
    backb_to_global: r3.Rigids,  # (N)
    torsion_angles_sin_cos: jnp.ndarray  # (N, 7, 2)
) -> r3.Rigids:  # (N, 8)
  """Compute rigid group frames from torsion angles.

  Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" lines 2-10
  Jumper et al. (2021) Suppl. Alg. 25 "makeRotX"

  Args:
    aatype: aatype for each residue
    backb_to_global: Rigid transformations describing transformation from
      backbone frame to global frame.
    torsion_angles_sin_cos: sin and cosine of the 7 torsion angles
  Returns:
    Frames corresponding to all the Sidechain Rigid Transforms
  """
  assert len(aatype.shape) == 1
  assert len(backb_to_global.rot.xx.shape) == 1
  assert len(torsion_angles_sin_cos.shape) == 3
  assert torsion_angles_sin_cos.shape[1] == 7
  assert torsion_angles_sin_cos.shape[2] == 2

  # Gather the default frames for all rigid groups.
  # r3.Rigids with shape (N, 8)
  m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
                           aatype)
  default_frames = r3.rigids_from_tensor4x4(m)

  # Create the rotation matrices according to the given angles (each frame is
  # defined such that its rotation is around the x-axis).
  sin_angles = torsion_angles_sin_cos[..., 0]
  cos_angles = torsion_angles_sin_cos[..., 1]

  # insert zero rotation for backbone group.
  num_residues, = aatype.shape
  sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
                               axis=-1)
  cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
                               axis=-1)
  zeros = jnp.zeros_like(sin_angles)
  ones = jnp.ones_like(sin_angles)

  # all_rots are r3.Rots with shape (N, 8)
  # Each is a rotation about the x-axis by the corresponding torsion angle.
  all_rots = r3.Rots(ones, zeros, zeros,
                     zeros, cos_angles, -sin_angles,
                     zeros, sin_angles, cos_angles)

  # Apply rotations to the frames.
  all_frames = r3.rigids_mul_rots(default_frames, all_rots)

  # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
  # the previous frame. So chain them up accordingly.
  chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames)
  chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames)
  chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames)

  chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames)

  chi2_frame_to_backb = r3.rigids_mul_rigids(chi1_frame_to_backb,
                                             chi2_frame_to_frame)
  chi3_frame_to_backb = r3.rigids_mul_rigids(chi2_frame_to_backb,
                                             chi3_frame_to_frame)
  chi4_frame_to_backb = r3.rigids_mul_rigids(chi3_frame_to_backb,
                                             chi4_frame_to_frame)

  # Recombine them to a r3.Rigids with shape (N, 8).
  def _concat_frames(xall, x5, x6, x7):
    # Keep groups 0-4 from the per-group frames, replace 5-7 by the chained
    # chi2/chi3/chi4-to-backbone frames computed above.
    return jnp.concatenate(
        [xall[:, 0:5], x5[:, None], x6[:, None], x7[:, None]], axis=-1)

  all_frames_to_backb = jax.tree_map(
      _concat_frames,
      all_frames,
      chi2_frame_to_backb,
      chi3_frame_to_backb,
      chi4_frame_to_backb)

  # Create the global frames.
  # shape (N, 8)
  all_frames_to_global = r3.rigids_mul_rigids(
      jax.tree_map(lambda x: x[:, None], backb_to_global),
      all_frames_to_backb)

  return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
    aatype: jnp.ndarray,  # (N)
    all_frames_to_global: r3.Rigids  # (N, 8)
) -> r3.Vecs:  # (N, 14)
  """Put atom literature positions (atom14 encoding) in each rigid group.

  Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

  Args:
    aatype: aatype for each residue.
    all_frames_to_global: All per residue coordinate frames.
  Returns:
    Positions of all atom coordinates in global frame.
  """

  # Pick the appropriate transform for every atom.
  residx_to_group_idx = utils.batched_gather(
      residue_constants.restype_atom14_to_rigid_group, aatype)
  group_mask = jax.nn.one_hot(
      residx_to_group_idx, num_classes=8)  # shape (N, 14, 8)

  # r3.Rigids with shape (N, 14)
  # One-hot select: for each atom, keep only the frame of its rigid group.
  map_atoms_to_global = jax.tree_map(
      lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
      all_frames_to_global)

  # Gather the literature atom positions for each residue.
  # r3.Vecs with shape (N, 14)
  lit_positions = r3.vecs_from_tensor(
      utils.batched_gather(
          residue_constants.restype_atom14_rigid_group_positions, aatype))

  # Transform each atom from its local frame to the global frame.
  # r3.Vecs with shape (N, 14)
  pred_positions = r3.rigids_mul_vecs(map_atoms_to_global, lit_positions)

  # Mask out non-existing atoms.
  mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
  pred_positions = jax.tree_map(lambda x: x * mask, pred_positions)

  return pred_positions
def extreme_ca_ca_distance_violations(
    pred_atom_positions: jnp.ndarray,  # (N, 37(14), 3)
    pred_atom_mask: jnp.ndarray,  # (N, 37(14))
    residue_index: jnp.ndarray,  # (N)
    max_angstrom_tolerance=1.5
    ) -> jnp.ndarray:
  """Counts residues whose Ca is a large distance from its neighbour.

  Measures the fraction of CA-CA pairs between consecutive amino acids that are
  more than 'max_angstrom_tolerance' apart.

  Args:
    pred_atom_positions: Atom positions in atom37/14 representation
    pred_atom_mask: Atom mask in atom37/14 representation
    residue_index: Residue index for given amino acid, this is assumed to be
      monotonically increasing.
    max_angstrom_tolerance: Maximum distance allowed to not count as violation.
  Returns:
    Fraction of consecutive CA-CA pairs with violation.
  """
  # CA is atom slot 1 in both atom37 and atom14 layouts.
  ca_pos_i = pred_atom_positions[:-1, 1, :]  # (N - 1, 3)
  ca_mask_i = pred_atom_mask[:-1, 1]  # (N - 1)
  ca_pos_j = pred_atom_positions[1:, 1, :]  # (N - 1, 3)
  ca_mask_j = pred_atom_mask[1:, 1]  # (N - 1)
  # Only residue pairs with consecutive indices (no chain break) count.
  no_gap = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
      jnp.float32)
  ca_ca_dist = jnp.sqrt(
      1e-6 + jnp.sum(squared_difference(ca_pos_i, ca_pos_j), axis=-1))
  is_violation = (ca_ca_dist - residue_constants.ca_ca) > max_angstrom_tolerance
  pair_mask = ca_mask_i * ca_mask_j * no_gap
  return utils.mask_mean(mask=pair_mask, value=is_violation)
def between_residue_bond_loss(
    pred_atom_positions: jnp.ndarray,  # (N, 37(14), 3)
    pred_atom_mask: jnp.ndarray,  # (N, 37(14))
    residue_index: jnp.ndarray,  # (N)
    aatype: jnp.ndarray,  # (N)
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0
) -> Dict[str, jnp.ndarray]:
  """Flat-bottom loss to penalize structural violations between residues.

  This is a loss penalizing any violation of the geometry around the peptide
  bond between consecutive amino acids. This loss corresponds to
  Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

  Args:
    pred_atom_positions: Atom positions in atom37/14 representation
    pred_atom_mask: Atom mask in atom37/14 representation
    residue_index: Residue index for given amino acid, this is assumed to be
      monotonically increasing.
    aatype: Amino acid type of given residue
    tolerance_factor_soft: soft tolerance factor measured in standard deviations
      of pdb distributions
    tolerance_factor_hard: hard tolerance factor measured in standard deviations
      of pdb distributions

  Returns:
    Dict containing:
      * 'c_n_loss_mean': Loss for peptide bond length violations
      * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
          by CA, C, N
      * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
          by C, N, CA
      * 'per_residue_loss_sum': sum of all losses for each residue
      * 'per_residue_violation_mask': mask denoting all residues with violation
          present.
  """
  assert len(pred_atom_positions.shape) == 3
  assert len(pred_atom_mask.shape) == 2
  assert len(residue_index.shape) == 1
  assert len(aatype.shape) == 1

  # Get the positions of the relevant backbone atoms.
  # Slots: 0 = N, 1 = CA, 2 = C (same in atom37 and atom14 layouts).
  this_ca_pos = pred_atom_positions[:-1, 1, :]  # (N - 1, 3)
  this_ca_mask = pred_atom_mask[:-1, 1]  # (N - 1)
  this_c_pos = pred_atom_positions[:-1, 2, :]  # (N - 1, 3)
  this_c_mask = pred_atom_mask[:-1, 2]  # (N - 1)
  next_n_pos = pred_atom_positions[1:, 0, :]  # (N - 1, 3)
  next_n_mask = pred_atom_mask[1:, 0]  # (N - 1)
  next_ca_pos = pred_atom_positions[1:, 1, :]  # (N - 1, 3)
  next_ca_mask = pred_atom_mask[1:, 1]  # (N - 1)
  # Pairs across a chain break (non-consecutive residue indices) are excluded.
  has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
      jnp.float32)

  # Compute loss for the C--N bond.
  c_n_bond_length = jnp.sqrt(1e-6 + jnp.sum(
      squared_difference(this_c_pos, next_n_pos), axis=-1))

  # The C-N bond to proline has slightly different length because of the ring.
  next_is_proline = (
      aatype[1:] == residue_constants.resname_to_idx['PRO']).astype(
          jnp.float32)
  gt_length = (
      (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
      + next_is_proline * residue_constants.between_res_bond_length_c_n[1])
  gt_stddev = (
      (1. - next_is_proline) *
      residue_constants.between_res_bond_length_stddev_c_n[0] +
      next_is_proline *
      residue_constants.between_res_bond_length_stddev_c_n[1])
  c_n_bond_length_error = jnp.sqrt(1e-6 +
                                   jnp.square(c_n_bond_length - gt_length))
  # Flat-bottom: zero loss within tolerance_factor_soft stddevs of the target.
  c_n_loss_per_residue = jax.nn.relu(
      c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
  mask = this_c_mask * next_n_mask * has_no_gap_mask
  c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
  c_n_violation_mask = mask * (
      c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))

  # Compute loss for the angles.
  ca_c_bond_length = jnp.sqrt(1e-6 + jnp.sum(
      squared_difference(this_ca_pos, this_c_pos), axis=-1))
  n_ca_bond_length = jnp.sqrt(1e-6 + jnp.sum(
      squared_difference(next_n_pos, next_ca_pos), axis=-1))

  c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[:, None]
  c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[:, None]
  n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[:, None]

  ca_c_n_cos_angle = jnp.sum(c_ca_unit_vec * c_n_unit_vec, axis=-1)
  gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
  # NOTE(review): this stddev is taken from the bond-length table rather than
  # from the ca_c_n cos-angle table (contrast with the c_n_ca branch below,
  # which uses between_res_cos_angles_c_n_ca[1]). Verify against upstream
  # whether this asymmetry is intentional.
  gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
  ca_c_n_cos_angle_error = jnp.sqrt(
      1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
  ca_c_n_loss_per_residue = jax.nn.relu(
      ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
  mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
  ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (
      jnp.sum(mask) + 1e-6)
  ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
                                  (tolerance_factor_hard * gt_stddev))

  c_n_ca_cos_angle = jnp.sum((-c_n_unit_vec) * n_ca_unit_vec, axis=-1)
  gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
  gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
  c_n_ca_cos_angle_error = jnp.sqrt(
      1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
  c_n_ca_loss_per_residue = jax.nn.relu(
      c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
  mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
  c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (
      jnp.sum(mask) + 1e-6)
  c_n_ca_violation_mask = mask * (
      c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))

  # Compute a per residue loss (equally distribute the loss to both
  # neighbouring residues).
  per_residue_loss_sum = (c_n_loss_per_residue +
                          ca_c_n_loss_per_residue +
                          c_n_ca_loss_per_residue)
  per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
                                jnp.pad(per_residue_loss_sum, [[1, 0]]))

  # Compute hard violations.
  violation_mask = jnp.max(
      jnp.stack([c_n_violation_mask,
                 ca_c_n_violation_mask,
                 c_n_ca_violation_mask]), axis=0)
  # A violating bond/angle marks both residues it touches.
  violation_mask = jnp.maximum(
      jnp.pad(violation_mask, [[0, 1]]),
      jnp.pad(violation_mask, [[1, 0]]))

  return {'c_n_loss_mean': c_n_loss,  # shape ()
          'ca_c_n_loss_mean': ca_c_n_loss,  # shape ()
          'c_n_ca_loss_mean': c_n_ca_loss,  # shape ()
          'per_residue_loss_sum': per_residue_loss_sum,  # shape (N)
          'per_residue_violation_mask': violation_mask  # shape (N)
         }
def between_residue_clash_loss(
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    atom14_atom_exists: jnp.ndarray,  # (N, 14)
    atom14_atom_radius: jnp.ndarray,  # (N, 14)
    residue_index: jnp.ndarray,  # (N)
    overlap_tolerance_soft: float = 1.5,
    overlap_tolerance_hard: float = 1.5
) -> Dict[str, jnp.ndarray]:
  """Loss to penalize steric clashes between residues.

  This is a loss penalizing any steric clashes due to non bonded atoms in
  different peptides coming too close. This loss corresponds to the part with
  different residues of
  Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

  Args:
    atom14_pred_positions: Predicted positions of atoms in
      global prediction frame
    atom14_atom_exists: Mask denoting whether atom at positions exists for given
      amino acid type
    atom14_atom_radius: Van der Waals radius for each atom.
    residue_index: Residue index for given amino acid.
    overlap_tolerance_soft: Soft tolerance factor.
    overlap_tolerance_hard: Hard tolerance factor.

  Returns:
    Dict containing:
      * 'mean_loss': average clash loss
      * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
      * 'per_atom_clash_mask': mask whether atom clashes with any other atom
          shape (N, 14)
  """
  assert len(atom14_pred_positions.shape) == 3
  assert len(atom14_atom_exists.shape) == 2
  assert len(atom14_atom_radius.shape) == 2
  assert len(residue_index.shape) == 1

  # Create the distance matrix.
  # (N, N, 14, 14)
  dists = jnp.sqrt(1e-10 + jnp.sum(
      squared_difference(
          atom14_pred_positions[:, None, :, None, :],
          atom14_pred_positions[None, :, None, :, :]),
      axis=-1))

  # Create the mask for valid distances.
  # shape (N, N, 14, 14)
  dists_mask = (atom14_atom_exists[:, None, :, None] *
                atom14_atom_exists[None, :, None, :])

  # Mask out all the duplicate entries in the lower triangular matrix.
  # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
  # are handled separately.
  dists_mask *= (
      residue_index[:, None, None, None] < residue_index[None, :, None, None])

  # Backbone C--N bond between subsequent residues is no clash.
  # Atom index 2 is the backbone C, index 0 the backbone N in atom14 layout.
  c_one_hot = jax.nn.one_hot(2, num_classes=14)
  n_one_hot = jax.nn.one_hot(0, num_classes=14)
  neighbour_mask = ((residue_index[:, None, None, None] +
                     1) == residue_index[None, :, None, None])
  c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
                                         None] * n_one_hot[None, None, None, :]
  dists_mask *= (1. - c_n_bonds)

  # Disulfide bridge between two cysteines is no clash.
  cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG')
  cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
  disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
                     cys_sg_one_hot[None, None, None, :])
  dists_mask *= (1. - disulfide_bonds)

  # Compute the lower bound for the allowed distances.
  # shape (N, N, 14, 14)
  dists_lower_bound = dists_mask * (atom14_atom_radius[:, None, :, None] +
                                    atom14_atom_radius[None, :, None, :])

  # Compute the error.
  # shape (N, N, 14, 14)
  dists_to_low_error = dists_mask * jax.nn.relu(
      dists_lower_bound - overlap_tolerance_soft - dists)

  # Compute the mean loss.
  # shape ()
  mean_loss = (jnp.sum(dists_to_low_error)
               / (1e-6 + jnp.sum(dists_mask)))

  # Compute the per atom loss sum.
  # shape (N, 14)
  per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
                       jnp.sum(dists_to_low_error, axis=[1, 3]))

  # Compute the hard clash mask.
  # shape (N, N, 14, 14)
  clash_mask = dists_mask * (
      dists < (dists_lower_bound - overlap_tolerance_hard))

  # Compute the per atom clash.
  # shape (N, 14)
  per_atom_clash_mask = jnp.maximum(
      jnp.max(clash_mask, axis=[0, 2]),
      jnp.max(clash_mask, axis=[1, 3]))

  return {'mean_loss': mean_loss,  # shape ()
          'per_atom_loss_sum': per_atom_loss_sum,  # shape (N, 14)
          'per_atom_clash_mask': per_atom_clash_mask  # shape (N, 14)
          }
def within_residue_violations(
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    atom14_atom_exists: jnp.ndarray,  # (N, 14)
    atom14_dists_lower_bound: jnp.ndarray,  # (N, 14, 14)
    atom14_dists_upper_bound: jnp.ndarray,  # (N, 14, 14)
    tighten_bounds_for_loss: float = 0.0,
) -> Dict[str, jnp.ndarray]:
  """Loss to penalize steric clashes within residues.

  This is a loss penalizing any steric violations or clashes of non-bonded atoms
  in a given peptide. This loss corresponds to the part with
  the same residues of
  Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

  Args:
    atom14_pred_positions: Predicted positions of atoms in
      global prediction frame
    atom14_atom_exists: Mask denoting whether atom at positions exists for given
      amino acid type
    atom14_dists_lower_bound: Lower bound on allowed distances.
    atom14_dists_upper_bound: Upper bound on allowed distances
    tighten_bounds_for_loss: Extra factor to tighten loss

  Returns:
    Dict containing:
      * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
      * 'per_atom_clash_mask': mask whether atom clashes with any other atom
          shape (N, 14)
  """
  assert len(atom14_pred_positions.shape) == 3
  assert len(atom14_atom_exists.shape) == 2
  assert len(atom14_dists_lower_bound.shape) == 3
  assert len(atom14_dists_upper_bound.shape) == 3

  # Compute the mask for each residue.
  # The eye() term removes the self-distance of each atom from the mask.
  # shape (N, 14, 14)
  dists_masks = (1. - jnp.eye(14, 14)[None])
  dists_masks *= (atom14_atom_exists[:, :, None] *
                  atom14_atom_exists[:, None, :])

  # Distance matrix
  # shape (N, 14, 14)
  dists = jnp.sqrt(1e-10 + jnp.sum(
      squared_difference(
          atom14_pred_positions[:, :, None, :],
          atom14_pred_positions[:, None, :, :]),
      axis=-1))

  # Compute the loss.
  # shape (N, 14, 14)
  dists_to_low_error = jax.nn.relu(
      atom14_dists_lower_bound + tighten_bounds_for_loss - dists)
  dists_to_high_error = jax.nn.relu(
      dists - (atom14_dists_upper_bound - tighten_bounds_for_loss))
  loss = dists_masks * (dists_to_low_error + dists_to_high_error)

  # Compute the per atom loss sum.
  # shape (N, 14)
  per_atom_loss_sum = (jnp.sum(loss, axis=1) +
                       jnp.sum(loss, axis=2))

  # Compute the violations mask.
  # shape (N, 14, 14)
  violations = dists_masks * ((dists < atom14_dists_lower_bound) |
                              (dists > atom14_dists_upper_bound))

  # Compute the per atom violations.
  # shape (N, 14)
  per_atom_violations = jnp.maximum(
      jnp.max(violations, axis=1), jnp.max(violations, axis=2))

  return {'per_atom_loss_sum': per_atom_loss_sum,  # shape (N, 14)
          'per_atom_violations': per_atom_violations  # shape (N, 14)
          }
def find_optimal_renaming(
    atom14_gt_positions: jnp.ndarray,  # (N, 14, 3)
    atom14_alt_gt_positions: jnp.ndarray,  # (N, 14, 3)
    atom14_atom_is_ambiguous: jnp.ndarray,  # (N, 14)
    atom14_gt_exists: jnp.ndarray,  # (N, 14)
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    atom14_atom_exists: jnp.ndarray,  # (N, 14)
) -> jnp.ndarray:  # (N):
  """Find optimal renaming for ground truth that maximizes LDDT.

  Jumper et al. (2021) Suppl. Alg. 26
  "renameSymmetricGroundTruthAtoms" lines 1-5

  Args:
    atom14_gt_positions: Ground truth positions in global frame of ground truth.
    atom14_alt_gt_positions: Alternate ground truth positions in global frame of
      ground truth with coordinates of ambiguous atoms swapped relative to
      'atom14_gt_positions'.
    atom14_atom_is_ambiguous: Mask denoting whether atom is among ambiguous
      atoms, see Jumper et al. (2021) Suppl. Table 3
    atom14_gt_exists: Mask denoting whether atom at positions exists in ground
      truth.
    atom14_pred_positions: Predicted positions of atoms in
      global prediction frame
    atom14_atom_exists: Mask denoting whether atom at positions exists for given
      amino acid type

  Returns:
    Float array of shape [N] with 1. where atom14_alt_gt_positions is closer to
    prediction and 0. otherwise
  """
  assert len(atom14_gt_positions.shape) == 3
  assert len(atom14_alt_gt_positions.shape) == 3
  assert len(atom14_atom_is_ambiguous.shape) == 2
  assert len(atom14_gt_exists.shape) == 2
  assert len(atom14_pred_positions.shape) == 3
  assert len(atom14_atom_exists.shape) == 2

  # Create the pred distance matrix.
  # shape (N, N, 14, 14)
  pred_dists = jnp.sqrt(1e-10 + jnp.sum(
      squared_difference(
          atom14_pred_positions[:, None, :, None, :],
          atom14_pred_positions[None, :, None, :, :]),
      axis=-1))

  # Compute distances for ground truth with original and alternative names.
  # shape (N, N, 14, 14)
  gt_dists = jnp.sqrt(1e-10 + jnp.sum(
      squared_difference(
          atom14_gt_positions[:, None, :, None, :],
          atom14_gt_positions[None, :, None, :, :]),
      axis=-1))
  alt_gt_dists = jnp.sqrt(1e-10 + jnp.sum(
      squared_difference(
          atom14_alt_gt_positions[:, None, :, None, :],
          atom14_alt_gt_positions[None, :, None, :, :]),
      axis=-1))

  # Compute LDDT's.
  # shape (N, N, 14, 14)
  lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
  alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))

  # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
  # in cols.
  # shape (N ,N, 14, 14)
  mask = (atom14_gt_exists[:, None, :, None] *  # rows
          atom14_atom_is_ambiguous[:, None, :, None] *  # rows
          atom14_gt_exists[None, :, None, :] *  # cols
          (1. - atom14_atom_is_ambiguous[None, :, None, :]))  # cols

  # Aggregate distances for each residue to the non-amibuguous atoms.
  # shape (N)
  per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
  alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])

  # Decide for each residue, whether alternative naming is better.
  # shape (N)
  alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)

  return alt_naming_is_better  # shape (N)
def frame_aligned_point_error(
    pred_frames: r3.Rigids,  # shape (num_frames)
    target_frames: r3.Rigids,  # shape (num_frames)
    frames_mask: jnp.ndarray,  # shape (num_frames)
    pred_positions: r3.Vecs,  # shape (num_positions)
    target_positions: r3.Vecs,  # shape (num_positions)
    positions_mask: jnp.ndarray,  # shape (num_positions)
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    epsilon: float = 1e-4) -> jnp.ndarray:  # shape ()
  """Measure point error under different alignments.

  Jumper et al. (2021) Suppl. Alg. 28 "computeFAPE"

  Computes error between two structures with B points under A alignments derived
  from the given pairs of frames.

  Args:
    pred_frames: num_frames reference frames for 'pred_positions'.
    target_frames: num_frames reference frames for 'target_positions'.
    frames_mask: Mask for frame pairs to use.
    pred_positions: num_positions predicted positions of the structure.
    target_positions: num_positions target positions of the structure.
    positions_mask: Mask on which positions to score.
    length_scale: length scale to divide loss by.
    l1_clamp_distance: Distance cutoff on error beyond which gradients will
      be zero.
    epsilon: small value used to regularize denominator for masked average.

  Returns:
    Masked Frame Aligned Point Error.
  """
  assert pred_frames.rot.xx.ndim == 1
  assert target_frames.rot.xx.ndim == 1
  assert frames_mask.ndim == 1, frames_mask.ndim
  assert pred_positions.x.ndim == 1
  assert target_positions.x.ndim == 1
  assert positions_mask.ndim == 1

  # Compute array of predicted positions in the predicted frames.
  # r3.Vecs (num_frames, num_positions)
  local_pred_pos = r3.rigids_mul_vecs(
      jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)),
      jax.tree_map(lambda x: x[None, :], pred_positions))

  # Compute array of target positions in the target frames.
  # r3.Vecs (num_frames, num_positions)
  local_target_pos = r3.rigids_mul_vecs(
      jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)),
      jax.tree_map(lambda x: x[None, :], target_positions))

  # Compute errors between the structures.
  # jnp.ndarray (num_frames, num_positions)
  error_dist = jnp.sqrt(
      r3.vecs_squared_distance(local_pred_pos, local_target_pos)
      + epsilon)

  # Clamping cuts gradients beyond the cutoff distance (FAPE clamp).
  if l1_clamp_distance:
    error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)

  normed_error = error_dist / length_scale
  normed_error *= jnp.expand_dims(frames_mask, axis=-1)
  normed_error *= jnp.expand_dims(positions_mask, axis=-2)

  normalization_factor = (
      jnp.sum(frames_mask, axis=-1) *
      jnp.sum(positions_mask, axis=-1))
  return (jnp.sum(normed_error, axis=(-2, -1)) /
          (epsilon + normalization_factor))
def _make_renaming_matrices():
  """Matrices to map atoms to symmetry partners in ambiguous case."""
  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
  # alternative groundtruth coordinates where the naming is swapped.
  residue_names = [
      residue_constants.restype_1to3[res] for res in residue_constants.restypes
  ] + ['UNK']
  # Start from identity: residues without ambiguous atoms keep their order.
  matrices = {name: np.eye(14, dtype=np.float32) for name in residue_names}
  for res_name, atom_swaps in (
      residue_constants.residue_atom_renaming_swaps.items()):
    atom14_names = residue_constants.restype_name_to_atom14_names[res_name]
    # Build the atom-index permutation induced by swapping each ambiguous pair.
    permutation = np.arange(14)
    for atom_a, atom_b in atom_swaps.items():
      idx_a = atom14_names.index(atom_a)
      idx_b = atom14_names.index(atom_b)
      permutation[idx_a], permutation[idx_b] = idx_b, idx_a
    # Turn the permutation into a one-hot renaming matrix.
    swap_matrix = np.zeros((14, 14), dtype=np.float32)
    swap_matrix[np.arange(14), permutation] = 1.
    matrices[res_name] = swap_matrix
  # Stack in restype order (20 standard residues + UNK): shape (21, 14, 14).
  return np.stack([matrices[name] for name in residue_names])
# Precomputed (21, 14, 14) renaming matrices, one per residue type.
RENAMING_MATRICES = _make_renaming_matrices()
def get_alt_atom14(aatype, positions, mask):
  """Get alternative atom14 positions.

  Constructs renamed atom positions for ambiguous residues.

  Jumper et al. (2021) Suppl. Table 3 "Ambiguous atom names due to 180 degree-
  rotation-symmetry"

  Args:
    aatype: Amino acid at given position
    positions: Atom positions as r3.Vecs in atom14 representation, (N, 14)
    mask: Atom masks in atom14 representation, (N, 14)

  Returns:
    renamed atom positions, renamed atom mask
  """
  # pick the transformation matrices for the given residue sequence
  # shape (num_res, 14, 14)
  renaming_transform = utils.batched_gather(
      jnp.asarray(RENAMING_MATRICES), aatype)

  # Broadcast each coordinate over the renaming matrix and contract: this
  # applies the per-residue atom permutation to every vector component.
  positions = jax.tree_map(lambda x: x[:, :, None], positions)
  alternative_positions = jax.tree_map(
      lambda x: jnp.sum(x, axis=1), positions * renaming_transform)

  # Create the mask for the alternative ground truth (differs from the
  # ground truth mask, if only one of the atoms in an ambiguous pair has a
  # ground truth position)
  alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)

  return alternative_positions, alternative_mask
alphafold/model/all_atom_multimer.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from
typing
import
Dict
,
Text
from
alphafold.common
import
residue_constants
from
alphafold.model
import
geometry
from
alphafold.model
import
utils
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
def squared_difference(x, y):
  """Returns the elementwise squared difference (x - y)**2."""
  delta = x - y
  return jnp.square(delta)
def _make_chi_atom_indices():
  """Returns atom indices needed to compute chi angles for all residue types.

  Returns:
    A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
    in the order specified in residue_constants.restypes + unknown residue type
    at the end. For chi angles which are not defined on the residue, the
    positions indices are by default set to 0.
  """
  indices_per_restype = []
  for one_letter in residue_constants.restypes:
    three_letter = residue_constants.restype_1to3[one_letter]
    chi_angle_atoms = residue_constants.chi_angles_atoms[three_letter]
    rows = [[residue_constants.atom_order[atom] for atom in chi]
            for chi in chi_angle_atoms]
    # Pad with all-zero rows for chi angles the residue does not define.
    rows.extend([[0, 0, 0, 0]] * (4 - len(rows)))
    indices_per_restype.append(rows)
  # The UNKNOWN residue defines no chi angles at all.
  indices_per_restype.append([[0, 0, 0, 0]] * 4)
  return np.array(indices_per_restype)
def _make_renaming_matrices():
  """Matrices to map atoms to symmetry partners in ambiguous case."""
  # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
  # alternative groundtruth coordinates where the naming is swapped.
  residue_names = [
      residue_constants.restype_1to3[res] for res in residue_constants.restypes
  ] + ['UNK']
  # Start from identity: residues without ambiguous atoms keep their order.
  matrices = {name: np.eye(14, dtype=np.float32) for name in residue_names}
  for res_name, atom_swaps in (
      residue_constants.residue_atom_renaming_swaps.items()):
    atom14_names = residue_constants.restype_name_to_atom14_names[res_name]
    # Build the atom-index permutation induced by swapping each ambiguous pair.
    permutation = np.arange(14)
    for atom_a, atom_b in atom_swaps.items():
      idx_a = atom14_names.index(atom_a)
      idx_b = atom14_names.index(atom_b)
      permutation[idx_a], permutation[idx_b] = idx_b, idx_a
    # Turn the permutation into a one-hot renaming matrix.
    swap_matrix = np.zeros((14, 14), dtype=np.float32)
    swap_matrix[np.arange(14), permutation] = 1.
    matrices[res_name] = swap_matrix
  # Stack in restype order (20 standard residues + UNK): shape (21, 14, 14).
  return np.stack([matrices[name] for name in residue_names])
def _make_restype_atom37_mask():
  """Mask of which atoms are present for which residue type in atom37."""
  # shape (21 restypes incl. UNK, 37 atom types); UNK row stays all-zero.
  mask = np.zeros([21, 37], dtype=np.float32)
  for restype_idx, one_letter in enumerate(residue_constants.restypes):
    three_letter = residue_constants.restype_1to3[one_letter]
    for atom_name in residue_constants.residue_atoms[three_letter]:
      mask[restype_idx, residue_constants.atom_order[atom_name]] = 1
  return mask
def _make_restype_atom14_mask():
  """Mask of which atoms are present for which residue type in atom14."""
  # One row per standard residue: 1. where the atom14 slot is named, else 0.
  rows = [
      [1. if name else 0.
       for name in residue_constants.restype_name_to_atom14_names[
           residue_constants.restype_1to3[rt]]]
      for rt in residue_constants.restypes
  ]
  # 'UNK' has no atoms.
  rows.append([0.] * 14)
  return np.array(rows, dtype=np.float32)
def _make_restype_atom37_to_atom14():
  """Map from atom37 to atom14 per residue type."""
  # mapping (restype, atom37) --> atom14; unnamed slots map to index 0.
  rows = []
  for rt in residue_constants.restypes:
    atom14_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]
    name_to_idx14 = {name: i for i, name in enumerate(atom14_names)}
    rows.append([name_to_idx14.get(name, 0)
                 for name in residue_constants.atom_types])
  # Dummy row for 'UNK'.
  rows.append([0] * 37)
  return np.array(rows, dtype=np.int32)
def _make_restype_atom14_to_atom37():
  """Map from atom14 to atom37 per residue type."""
  # mapping (restype, atom14) --> atom37; empty atom14 slots map to index 0.
  rows = []
  for rt in residue_constants.restypes:
    atom14_names = residue_constants.restype_name_to_atom14_names[
        residue_constants.restype_1to3[rt]]
    rows.append([residue_constants.atom_order[name] if name else 0
                 for name in atom14_names])
  # Add dummy mapping for restype 'UNK'.
  rows.append([0] * 14)
  return np.array(rows, dtype=np.int32)
def _make_restype_atom14_is_ambiguous():
  """Mask which atoms are ambiguous in atom14."""
  # create an ambiguous atoms mask. shape: (21, 14)
  ambiguous = np.zeros((21, 14), dtype=np.float32)
  for res_name, swaps in residue_constants.residue_atom_renaming_swaps.items():
    restype = residue_constants.restype_order[
        residue_constants.restype_3to1[res_name]]
    atom14_names = residue_constants.restype_name_to_atom14_names[res_name]
    # Both members of every swappable pair are marked ambiguous.
    for atom_a, atom_b in swaps.items():
      ambiguous[restype, atom14_names.index(atom_a)] = 1
      ambiguous[restype, atom14_names.index(atom_b)] = 1
  return ambiguous
def _make_restype_rigidgroup_base_atom37_idx():
  """Create Map from rigidgroups to atom37 indices."""
  # Create an array with the atom names.
  # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3)
  # object dtype so the cells can hold variable-length atom-name strings.
  base_atom_names = np.full([21, 8, 3], '', dtype=object)

  # 0: backbone frame
  base_atom_names[:, 0, :] = ['C', 'CA', 'N']

  # 3: 'psi-group'
  base_atom_names[:, 3, :] = ['CA', 'C', 'O']

  # 4,5,6,7: 'chi1,2,3,4-group'
  for restype, restype_letter in enumerate(residue_constants.restypes):
    resname = residue_constants.restype_1to3[restype_letter]
    for chi_idx in range(4):
      if residue_constants.chi_angles_mask[restype][chi_idx]:
        # The first atom of each chi-angle quadruple belongs to the previous
        # frame, so only atoms 1: define this group's frame.
        atom_names = residue_constants.chi_angles_atoms[resname][chi_idx]
        base_atom_names[restype, chi_idx + 4, :] = atom_names[1:]

  # Translate atom names into atom37 indices.
  # '' (groups that do not exist) maps to index 0.
  lookuptable = residue_constants.atom_order.copy()
  lookuptable[''] = 0
  restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])(
      base_atom_names)
  return restype_rigidgroup_base_atom37_idx
# Precomputed lookup tables, built once at import time and indexed by aatype
# via utils.batched_gather in the accessors below.
CHI_ATOM_INDICES = _make_chi_atom_indices()
RENAMING_MATRICES = _make_renaming_matrices()
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
RESTYPE_ATOM37_MASK = _make_restype_atom37_mask()
RESTYPE_ATOM14_MASK = _make_restype_atom14_mask()
RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous()
RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx()

# Create mask for existing rigid groups.
# Groups 0 (backbone) and 3 (psi) exist for every residue type; groups 4-7
# (chi1-4) exist only where chi_angles_mask says so (UNK row stays zero).
RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32)
RESTYPE_RIGIDGROUP_MASK[:, 0] = 1
RESTYPE_RIGIDGROUP_MASK[:, 3] = 1
RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask
def get_atom37_mask(aatype):
  """Per-residue atom existence mask in atom37 layout for `aatype`."""
  mask_table = jnp.asarray(RESTYPE_ATOM37_MASK)
  return utils.batched_gather(mask_table, aatype)
def get_atom14_mask(aatype):
  """Per-residue atom existence mask in atom14 layout for `aatype`."""
  mask_table = jnp.asarray(RESTYPE_ATOM14_MASK)
  return utils.batched_gather(mask_table, aatype)
def get_atom14_is_ambiguous(aatype):
  """Per-residue mask of naming-ambiguous atoms in atom14 layout."""
  ambiguity_table = jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS)
  return utils.batched_gather(ambiguity_table, aatype)
def get_atom14_to_atom37_map(aatype):
  """Per-residue index map from atom14 slots to atom37 slots."""
  index_table = jnp.asarray(RESTYPE_ATOM14_TO_ATOM37)
  return utils.batched_gather(index_table, aatype)
def get_atom37_to_atom14_map(aatype):
  """Per-residue index map from atom37 slots to atom14 slots."""
  index_table = jnp.asarray(RESTYPE_ATOM37_TO_ATOM14)
  return utils.batched_gather(index_table, aatype)
def atom14_to_atom37(atom14_data: jnp.ndarray,  # (N, 14, ...)
                     aatype: jnp.ndarray
                    ) -> jnp.ndarray:  # (N, 37, ...)
  """Convert atom14 to atom37 representation.

  Gathers per-residue data from the atom14 layout into atom37 slots and zeroes
  out slots that do not exist for the residue type.
  """
  assert len(atom14_data.shape) in [2, 3]
  idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype)
  atom37_data = utils.batched_gather(
      atom14_data, idx_atom37_to_atom14, batch_dims=1)
  atom37_mask = get_atom37_mask(aatype)
  if len(atom14_data.shape) == 2:
    # Per-atom scalars: mask directly.
    atom37_data *= atom37_mask
  elif len(atom14_data.shape) == 3:
    # Per-atom vectors (e.g. coordinates): broadcast the mask over the last
    # axis and match the data dtype before multiplying.
    atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype)
  return atom37_data
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
  """Convert Atom37 positions to Atom14 positions.

  Returns:
    atom14_positions: positions gathered into atom14 layout, zeroed where the
      atom is masked out.
    atom14_mask: existence mask combining the input atom37 mask and the
      per-residue-type atom14 mask.
  """
  residx_atom14_to_atom37 = utils.batched_gather(
      jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype)
  atom14_mask = utils.batched_gather(
      all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32)
  # create a mask for known groundtruth positions
  atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype)
  # gather the groundtruth positions
  atom14_positions = jax.tree_map(
      lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1),
      all_atom_pos)
  atom14_positions = atom14_mask * atom14_positions
  return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask):
  """Get alternative atom14 positions.

  Applies the per-residue renaming matrices so that ambiguous atom pairs
  (Jumper et al. 2021, Suppl. Table 3) are swapped.

  Args:
    aatype: Amino acid type per residue.
    positions: Atom positions in atom14 representation, (N, 14).
    mask: Atom mask in atom14 representation, (N, 14).

  Returns:
    renamed atom positions, renamed atom mask
  """
  # pick the transformation matrices for the given residue sequence
  # shape (num_res, 14, 14)
  renaming_transform = utils.batched_gather(
      jnp.asarray(RENAMING_MATRICES), aatype)

  # Contract each position over the renaming matrix to permute atoms.
  alternative_positions = jax.tree_map(
      lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform)

  # Create the mask for the alternative ground truth (differs from the
  # ground truth mask, if only one of the atoms in an ambiguous pair has a
  # ground truth position)
  alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1)

  return alternative_positions, alternative_mask
def atom37_to_frames(
    aatype: jnp.ndarray,  # (...)
    all_atom_positions: geometry.Vec3Array,  # (..., 37)
    all_atom_mask: jnp.ndarray,  # (..., 37)
) -> Dict[Text, jnp.ndarray]:
  """Computes the frames for the up to 8 rigid groups for each residue."""
  # 0: 'backbone group',
  # 1: 'pre-omega-group', (empty)
  # 2: 'phi-group', (currently empty, because it defines only hydrogens)
  # 3: 'psi-group',
  # 4,5,6,7: 'chi1,2,3,4-group'
  aatype_in_shape = aatype.shape

  # If there is a batch axis, just flatten it away, and reshape everything
  # back at the end of the function.
  aatype = jnp.reshape(aatype, [-1])
  all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]),
                                    all_atom_positions)
  all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37])

  # Compute the gather indices for all residues in the chain.
  # shape (N, 8, 3)
  residx_rigidgroup_base_atom37_idx = utils.batched_gather(
      RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype)

  # Gather the base atom positions for each rigid group.
  base_atom_pos = jax.tree_map(
      lambda x: utils.batched_gather(  # pylint: disable=g-long-lambda
          x, residx_rigidgroup_base_atom37_idx, batch_dims=1),
      all_atom_positions)

  # Compute the Rigids.
  point_on_neg_x_axis = base_atom_pos[:, :, 0]
  origin = base_atom_pos[:, :, 1]
  point_on_xy_plane = base_atom_pos[:, :, 2]
  gt_rotation = geometry.Rot3Array.from_two_vectors(
      origin - point_on_neg_x_axis, point_on_xy_plane - origin)

  gt_frames = geometry.Rigid3Array(gt_rotation, origin)

  # Compute a mask whether the group exists.
  # (N, 8)
  group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype)

  # Compute a mask whether ground truth exists for the group
  gt_atoms_exist = utils.batched_gather(  # shape (N, 8, 3)
      all_atom_mask.astype(jnp.float32),
      residx_rigidgroup_base_atom37_idx,
      batch_dims=1)
  # A frame's ground truth exists only if all three defining atoms exist.
  gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists  # (N, 8)

  # Adapt backbone frame to old convention (mirror x-axis and z-axis).
  rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
  rots[0, 0, 0] = -1
  rots[0, 2, 2] = -1
  gt_frames = gt_frames.compose_rotation(geometry.Rot3Array.from_array(rots))

  # The frames for ambiguous rigid groups are just rotated by 180 degree around
  # the x-axis. The ambiguous group is always the last chi-group.
  restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
  restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1])

  for resname, _ in residue_constants.residue_atom_renaming_swaps.items():
    restype = residue_constants.restype_order[
        residue_constants.restype_3to1[resname]]
    chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1)
    restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
    restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
    restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1

  # Gather the ambiguity information for each residue.
  residx_rigidgroup_is_ambiguous = utils.batched_gather(
      restype_rigidgroup_is_ambiguous, aatype)
  ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype)
  ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)

  # Create the alternative ground truth frames.
  alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)

  fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,))

  # reshape back to original residue layout
  gt_frames = jax.tree_map(fix_shape, gt_frames)
  gt_exists = fix_shape(gt_exists)
  group_exists = fix_shape(group_exists)
  residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
  alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames)

  return {
      'rigidgroups_gt_frames': gt_frames,  # Rigid (..., 8)
      'rigidgroups_gt_exists': gt_exists,  # (..., 8)
      'rigidgroups_group_exists': group_exists,  # (..., 8)
      'rigidgroups_group_is_ambiguous':
          residx_rigidgroup_is_ambiguous,  # (..., 8)
      'rigidgroups_alt_gt_frames': alt_gt_frames,  # Rigid (..., 8)
  }
def torsion_angles_to_frames(
    aatype: jnp.ndarray,  # (N)
    backb_to_global: geometry.Rigid3Array,  # (N)
    torsion_angles_sin_cos: jnp.ndarray  # (N, 7, 2)
) -> geometry.Rigid3Array:  # (N, 8)
  """Compute rigid group frames from torsion angles."""
  assert len(aatype.shape) == 1, (
      f'Expected array of rank 1, got array with shape: {aatype.shape}.')
  assert len(backb_to_global.rotation.shape) == 1, (
      f'Expected array of rank 1, got array with shape: '
      f'{backb_to_global.rotation.shape}')
  assert len(torsion_angles_sin_cos.shape) == 3, (
      f'Expected array of rank 3, got array with shape: '
      f'{torsion_angles_sin_cos.shape}')
  assert torsion_angles_sin_cos.shape[1] == 7, (
      f'wrong shape {torsion_angles_sin_cos.shape}')
  assert torsion_angles_sin_cos.shape[2] == 2, (
      f'wrong shape {torsion_angles_sin_cos.shape}')

  # Gather the default frames for all rigid groups.
  # geometry.Rigid3Array with shape (N, 8)
  m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame,
                           aatype)
  default_frames = geometry.Rigid3Array.from_array4x4(m)

  # Create the rotation matrices according to the given angles (each frame is
  # defined such that its rotation is around the x-axis).
  sin_angles = torsion_angles_sin_cos[..., 0]
  cos_angles = torsion_angles_sin_cos[..., 1]

  # insert zero rotation for backbone group.
  num_residues, = aatype.shape
  sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles],
                               axis=-1)
  cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles],
                               axis=-1)
  zeros = jnp.zeros_like(sin_angles)
  ones = jnp.ones_like(sin_angles)

  # all_rots are geometry.Rot3Array with shape (N, 8)
  # Row-major x-axis rotation: [[1, 0, 0], [0, cos, -sin], [0, sin, cos]].
  all_rots = geometry.Rot3Array(ones, zeros, zeros,
                                zeros, cos_angles, -sin_angles,
                                zeros, sin_angles, cos_angles)

  # Apply rotations to the frames.
  all_frames = default_frames.compose_rotation(all_rots)

  # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
  # the previous frame. So chain them up accordingly.
  chi1_frame_to_backb = all_frames[:, 4]
  chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5]
  chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
  chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]

  all_frames_to_backb = jax.tree_map(
      lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
      chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
      chi4_frame_to_backb[:, None])

  # Create the global frames.
  # shape (N, 8)
  all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb
  return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
    aatype: jnp.ndarray,  # (N)
    all_frames_to_global: geometry.Rigid3Array  # (N, 8)
) -> geometry.Vec3Array:  # (N, 14)
  """Put atom literature positions (atom14 encoding) in each rigid group.

  Each atom14 slot of each residue belongs to exactly one of the 8 rigid
  groups; the atom's idealized ("literature") local coordinates are mapped
  through that group's global frame to produce predicted positions.
  """
  # Pick the appropriate transform for every atom: for each (residue, atom14)
  # slot, the index (0..7) of the rigid group the atom belongs to.
  residx_to_group_idx = utils.batched_gather(
      residue_constants.restype_atom14_to_rigid_group, aatype)
  group_mask = jax.nn.one_hot(
      residx_to_group_idx, num_classes=8)  # shape (N, 14, 8)

  # Select each atom's frame by summing the one-hot-masked frame components.
  # geometry.Rigid3Array with shape (N, 14)
  map_atoms_to_global = jax.tree_map(
      lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1),
      all_frames_to_global)

  # Gather the literature atom positions for each residue.
  # geometry.Vec3Array with shape (N, 14)
  lit_positions = geometry.Vec3Array.from_array(
      utils.batched_gather(
          residue_constants.restype_atom14_rigid_group_positions, aatype))

  # Transform each atom from its local frame to the global frame.
  # geometry.Vec3Array with shape (N, 14)
  pred_positions = map_atoms_to_global.apply_to_point(lit_positions)

  # Mask out non-existing atoms (zero positions for atom14 slots the residue
  # type does not have).
  mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype)
  pred_positions = pred_positions * mask

  return pred_positions
def extreme_ca_ca_distance_violations(
    positions: geometry.Vec3Array,  # (N, 37(14))
    mask: jnp.ndarray,  # (N, 37(14))
    residue_index: jnp.ndarray,  # (N)
    max_angstrom_tolerance=1.5
    ) -> jnp.ndarray:
  """Counts residues whose Ca is a large distance from its neighbor.

  Returns the masked mean (fraction) of sequence-adjacent residue pairs whose
  CA-CA distance exceeds the ideal value by more than the tolerance.
  """
  # CA is atom index 1. Pair up each residue's CA with its successor's CA.
  ca_pos_prev = positions[:-1, 1]  # (N - 1,)
  ca_pos_next = positions[1:, 1]   # (N - 1,)
  ca_mask_prev = mask[:-1, 1]      # (N - 1)
  ca_mask_next = mask[1:, 1]       # (N - 1)
  # Only pairs with consecutive residue indices count (no chain break).
  adjacent = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
      jnp.float32)
  pair_mask = ca_mask_prev * ca_mask_next * adjacent
  ca_ca_dist = geometry.euclidean_distance(ca_pos_prev, ca_pos_next, 1e-6)
  is_violation = (ca_ca_dist - residue_constants.ca_ca) > max_angstrom_tolerance
  return utils.mask_mean(mask=pair_mask, value=is_violation)
def between_residue_bond_loss(
    pred_atom_positions: geometry.Vec3Array,  # (N, 37(14))
    pred_atom_mask: jnp.ndarray,  # (N, 37(14))
    residue_index: jnp.ndarray,  # (N)
    aatype: jnp.ndarray,  # (N)
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0
) -> Dict[Text, jnp.ndarray]:
  """Flat-bottom loss to penalize structural violations between residues.

  Penalizes deviations of the inter-residue C--N peptide bond length and the
  CA-C-N / C-N-CA bond angles from their literature values, beyond
  `tolerance_factor_soft` standard deviations (soft loss) or
  `tolerance_factor_hard` standard deviations (hard violation mask).
  """
  assert len(pred_atom_positions.shape) == 2
  assert len(pred_atom_mask.shape) == 2
  assert len(residue_index.shape) == 1
  assert len(aatype.shape) == 1

  # Get the positions of the relevant backbone atoms.
  # Atom order within a residue: 0 = N, 1 = CA, 2 = C.
  this_ca_pos = pred_atom_positions[:-1, 1]  # (N - 1)
  this_ca_mask = pred_atom_mask[:-1, 1]  # (N - 1)
  this_c_pos = pred_atom_positions[:-1, 2]  # (N - 1)
  this_c_mask = pred_atom_mask[:-1, 2]  # (N - 1)
  next_n_pos = pred_atom_positions[1:, 0]  # (N - 1)
  next_n_mask = pred_atom_mask[1:, 0]  # (N - 1)
  next_ca_pos = pred_atom_positions[1:, 1]  # (N - 1)
  next_ca_mask = pred_atom_mask[1:, 1]  # (N - 1)
  # 1.0 where residue i+1 directly follows residue i (no chain break).
  has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype(
      jnp.float32)

  # Compute loss for the C--N bond.
  c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6)

  # The C-N bond to proline has slightly different length because of the ring.
  next_is_proline = (
      aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32)
  gt_length = (
      (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0]
      + next_is_proline * residue_constants.between_res_bond_length_c_n[1])
  gt_stddev = (
      (1. - next_is_proline) *
      residue_constants.between_res_bond_length_stddev_c_n[0] +
      next_is_proline *
      residue_constants.between_res_bond_length_stddev_c_n[1])
  # sqrt(eps + x^2) rather than abs() keeps the gradient finite at zero.
  c_n_bond_length_error = jnp.sqrt(1e-6 +
                                   jnp.square(c_n_bond_length - gt_length))
  c_n_loss_per_residue = jax.nn.relu(
      c_n_bond_length_error - tolerance_factor_soft * gt_stddev)
  mask = this_c_mask * next_n_mask * has_no_gap_mask
  c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6)
  c_n_violation_mask = mask * (
      c_n_bond_length_error > (tolerance_factor_hard * gt_stddev))

  # Compute loss for the angles (via cosines of the unit bond vectors).
  c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6)
  c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length
  n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6)

  ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec)
  gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
  # NOTE(review): this stddev is taken from the C--N *bond length* table, not
  # from between_res_cos_angles_ca_c_n[1] as the symmetry with the C-N-CA
  # branch below would suggest — looks intentional in the released weights'
  # training setup, but confirm before changing.
  gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
  ca_c_n_cos_angle_error = jnp.sqrt(
      1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle))
  ca_c_n_loss_per_residue = jax.nn.relu(
      ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev)
  mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
  ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) +
                                                           1e-6)
  ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
                                  (tolerance_factor_hard * gt_stddev))

  c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec)
  gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
  gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
  c_n_ca_cos_angle_error = jnp.sqrt(
      1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle))
  c_n_ca_loss_per_residue = jax.nn.relu(
      c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev)
  mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
  c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) +
                                                           1e-6)
  c_n_ca_violation_mask = mask * (
      c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev))

  # Compute a per residue loss (equally distribute the loss to both
  # neighbouring residues).
  per_residue_loss_sum = (c_n_loss_per_residue + ca_c_n_loss_per_residue +
                          c_n_ca_loss_per_residue)
  # Pad on each side so each residue gets half of each adjacent pair's loss.
  per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) +
                                jnp.pad(per_residue_loss_sum, [[1, 0]]))

  # Compute hard violations: a residue is flagged if any of its three
  # geometric terms violates on either the preceding or the following bond.
  violation_mask = jnp.max(
      jnp.stack([c_n_violation_mask,
                 ca_c_n_violation_mask,
                 c_n_ca_violation_mask]), axis=0)
  violation_mask = jnp.maximum(
      jnp.pad(violation_mask, [[0, 1]]),
      jnp.pad(violation_mask, [[1, 0]]))

  return {'c_n_loss_mean': c_n_loss,  # shape ()
          'ca_c_n_loss_mean': ca_c_n_loss,  # shape ()
          'c_n_ca_loss_mean': c_n_ca_loss,  # shape ()
          'per_residue_loss_sum': per_residue_loss_sum,  # shape (N)
          'per_residue_violation_mask': violation_mask  # shape (N)
         }
def between_residue_clash_loss(
    pred_positions: geometry.Vec3Array,  # (N, 14)
    atom_exists: jnp.ndarray,  # (N, 14)
    atom_radius: jnp.ndarray,  # (N, 14)
    residue_index: jnp.ndarray,  # (N)
    asym_id: jnp.ndarray,  # (N)
    overlap_tolerance_soft=1.5,
    overlap_tolerance_hard=1.5
) -> Dict[Text, jnp.ndarray]:
  """Loss to penalize steric clashes between residues.

  Two atoms in different residues clash when their distance falls below the
  sum of their van der Waals radii (minus a tolerance). Bonded pairs
  (backbone C--N between consecutive residues of the same chain, and CYS
  SG--SG disulfides) are excluded.
  """
  assert len(pred_positions.shape) == 2
  assert len(atom_exists.shape) == 2
  assert len(atom_radius.shape) == 2
  assert len(residue_index.shape) == 1

  # Create the distance matrix.
  # (N, N, 14, 14)
  dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
                                      pred_positions[None, :, None, :], 1e-10)

  # Create the mask for valid distances.
  # shape (N, N, 14, 14)
  dists_mask = (atom_exists[:, None, :, None] *
                atom_exists[None, :, None, :])

  # Mask out all the duplicate entries in the lower triangular matrix.
  # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
  # are handled separately.
  dists_mask *= (
      residue_index[:, None, None, None] < residue_index[None, :, None, None])

  # Backbone C--N bond between subsequent residues is no clash.
  # Atom 2 is backbone C, atom 0 is backbone N in atom14 ordering.
  c_one_hot = jax.nn.one_hot(2, num_classes=14)
  n_one_hot = jax.nn.one_hot(0, num_classes=14)
  neighbour_mask = ((residue_index[:, None] + 1) == residue_index[None, :])
  # Only consecutive residues within the same chain are bonded.
  neighbour_mask &= (asym_id[:, None] == asym_id[None, :])
  neighbour_mask = neighbour_mask[..., None, None]
  c_n_bonds = neighbour_mask * c_one_hot[None, None, :,
                                         None] * n_one_hot[None, None, None, :]
  dists_mask *= (1. - c_n_bonds)

  # Disulfide bridge between two cysteines is no clash.
  cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index(
      'SG')
  cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14)
  disulfide_bonds = (cys_sg_one_hot[None, None, :, None] *
                     cys_sg_one_hot[None, None, None, :])
  dists_mask *= (1. - disulfide_bonds)

  # Compute the lower bound for the allowed distances.
  # shape (N, N, 14, 14)
  dists_lower_bound = dists_mask * (atom_radius[:, None, :, None] +
                                    atom_radius[None, :, None, :])

  # Compute the error.
  # shape (N, N, 14, 14)
  dists_to_low_error = dists_mask * jax.nn.relu(
      dists_lower_bound - overlap_tolerance_soft - dists)

  # Compute the mean loss.
  # shape ()
  mean_loss = (jnp.sum(dists_to_low_error)
               / (1e-6 + jnp.sum(dists_mask)))

  # Compute the per atom loss sum (summing over both pair orientations since
  # the mask keeps only the upper triangle).
  # shape (N, 14)
  per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) +
                       jnp.sum(dists_to_low_error, axis=[1, 3]))

  # Compute the hard clash mask.
  # shape (N, N, 14, 14)
  clash_mask = dists_mask * (
      dists < (dists_lower_bound - overlap_tolerance_hard))

  # Compute the per atom clash.
  # shape (N, 14)
  per_atom_clash_mask = jnp.maximum(
      jnp.max(clash_mask, axis=[0, 2]),
      jnp.max(clash_mask, axis=[1, 3]))

  return {'mean_loss': mean_loss,  # shape ()
          'per_atom_loss_sum': per_atom_loss_sum,  # shape (N, 14)
          'per_atom_clash_mask': per_atom_clash_mask  # shape (N, 14)
         }
def within_residue_violations(
    pred_positions: geometry.Vec3Array,  # (N, 14)
    atom_exists: jnp.ndarray,  # (N, 14)
    dists_lower_bound: jnp.ndarray,  # (N, 14, 14)
    dists_upper_bound: jnp.ndarray,  # (N, 14, 14)
    tighten_bounds_for_loss=0.0,
) -> Dict[Text, jnp.ndarray]:
  """Find within-residue violations.

  Penalizes intra-residue atom pairs whose distance falls outside the given
  per-pair [lower, upper] bounds; `tighten_bounds_for_loss` shrinks the
  allowed interval for the soft loss only (the hard violation mask uses the
  raw bounds).
  """
  assert len(pred_positions.shape) == 2
  assert len(atom_exists.shape) == 2
  assert len(dists_lower_bound.shape) == 3
  assert len(dists_upper_bound.shape) == 3

  # Compute the mask for each residue: exclude self-pairs (diagonal) and
  # pairs where either atom does not exist.
  # shape (N, 14, 14)
  dists_masks = (1. - jnp.eye(14, 14)[None])
  dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :])

  # Distance matrix
  # shape (N, 14, 14)
  dists = geometry.euclidean_distance(pred_positions[:, :, None],
                                      pred_positions[:, None, :], 1e-10)

  # Compute the loss.
  # shape (N, 14, 14)
  dists_to_low_error = jax.nn.relu(dists_lower_bound +
                                   tighten_bounds_for_loss - dists)
  dists_to_high_error = jax.nn.relu(dists + tighten_bounds_for_loss -
                                    dists_upper_bound)
  loss = dists_masks * (dists_to_low_error + dists_to_high_error)

  # Compute the per atom loss sum (each pair contributes to both atoms).
  # shape (N, 14)
  per_atom_loss_sum = (jnp.sum(loss, axis=1) +
                       jnp.sum(loss, axis=2))

  # Compute the violations mask.
  # shape (N, 14, 14)
  violations = dists_masks * ((dists < dists_lower_bound) |
                              (dists > dists_upper_bound))

  # Compute the per atom violations.
  # shape (N, 14)
  per_atom_violations = jnp.maximum(jnp.max(violations, axis=1),
                                    jnp.max(violations, axis=2))

  return {'per_atom_loss_sum': per_atom_loss_sum,  # shape (N, 14)
          'per_atom_violations': per_atom_violations  # shape (N, 14)
         }
def find_optimal_renaming(
    gt_positions: geometry.Vec3Array,  # (N, 14)
    alt_gt_positions: geometry.Vec3Array,  # (N, 14)
    atom_is_ambiguous: jnp.ndarray,  # (N, 14)
    gt_exists: jnp.ndarray,  # (N, 14)
    pred_positions: geometry.Vec3Array,  # (N, 14)
) -> jnp.ndarray:  # (N):
  """Find optimal renaming for ground truth that maximizes LDDT.

  For residues with name-ambiguous atoms (e.g. symmetric side chains), decides
  per residue whether the original or the alternative (swapped) ground-truth
  naming agrees better with the prediction, measured against the
  non-ambiguous atoms of all other residues.
  """
  assert len(gt_positions.shape) == 2
  assert len(alt_gt_positions.shape) == 2
  assert len(atom_is_ambiguous.shape) == 2
  assert len(gt_exists.shape) == 2
  assert len(pred_positions.shape) == 2

  # Create the pred distance matrix.
  # shape (N, N, 14, 14)
  pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None],
                                           pred_positions[None, :, None, :],
                                           1e-10)

  # Compute distances for ground truth with original and alternative names.
  # shape (N, N, 14, 14)
  gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None],
                                         gt_positions[None, :, None, :],
                                         1e-10)

  alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :,
                                                              None],
                                             alt_gt_positions[None, :, None, :],
                                             1e-10)

  # Compute LDDT's (here: absolute pairwise-distance discrepancies, smaller
  # is better).
  # shape (N, N, 14, 14)
  lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists))
  alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists))

  # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms
  # in cols.
  # shape (N ,N, 14, 14)
  mask = (gt_exists[:, None, :, None] *  # rows
          atom_is_ambiguous[:, None, :, None] *  # rows
          gt_exists[None, :, None, :] *  # cols
          (1. - atom_is_ambiguous[None, :, None, :]))  # cols

  # Aggregate distances for each residue to the non-amibuguous atoms.
  # shape (N)
  per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3])
  alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3])

  # Decide for each residue, whether alternative naming is better.
  # shape (N)
  alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32)

  return alt_naming_is_better  # shape (N)
def frame_aligned_point_error(
    pred_frames: geometry.Rigid3Array,  # shape (num_frames)
    target_frames: geometry.Rigid3Array,  # shape (num_frames)
    frames_mask: jnp.ndarray,  # shape (num_frames)
    pred_positions: geometry.Vec3Array,  # shape (num_positions)
    target_positions: geometry.Vec3Array,  # shape (num_positions)
    positions_mask: jnp.ndarray,  # shape (num_positions)
    pair_mask: jnp.ndarray,  # shape (num_frames, num_positions)
    l1_clamp_distance: float,
    length_scale=20.,
    epsilon=1e-4) -> jnp.ndarray:  # shape ()
  """Measure point error under different alignments.

  Computes error between two structures with B points
  under A alignments derived from the given pairs of frames.
  Args:
    pred_frames: num_frames reference frames for 'pred_positions'.
    target_frames: num_frames reference frames for 'target_positions'.
    frames_mask: Mask for frame pairs to use.
    pred_positions: num_positions predicted positions of the structure.
    target_positions: num_positions target positions of the structure.
    positions_mask: Mask on which positions to score.
    pair_mask: A (num_frames, num_positions) mask to use in the loss, useful
      for separating intra from inter chain losses.
    l1_clamp_distance: Distance cutoff on error beyond which gradients will
      be zero.
    length_scale: length scale to divide loss by.
    epsilon: small value used to regularize denominator for masked average.
  Returns:
    Masked Frame aligned point error.
  """
  # For now we do not allow any batch dimensions.
  assert len(pred_frames.rotation.shape) == 1
  assert len(target_frames.rotation.shape) == 1
  assert frames_mask.ndim == 1
  assert pred_positions.x.ndim == 1
  assert target_positions.x.ndim == 1
  assert positions_mask.ndim == 1

  # Compute array of predicted positions in the predicted frames.
  # geometry.Vec3Array (num_frames, num_positions)
  local_pred_pos = pred_frames[:, None].inverse().apply_to_point(
      pred_positions[None, :])

  # Compute array of target positions in the target frames.
  # geometry.Vec3Array (num_frames, num_positions)
  local_target_pos = target_frames[:, None].inverse().apply_to_point(
      target_positions[None, :])

  # Compute errors between the structures.
  # jnp.ndarray (num_frames, num_positions)
  error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos,
                                           epsilon)

  # Clamp large errors so their gradient is zero, then normalize to the
  # dimensionless FAPE scale.
  clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance)

  normed_error = clipped_error_dist / length_scale
  normed_error *= jnp.expand_dims(frames_mask, axis=-1)
  normed_error *= jnp.expand_dims(positions_mask, axis=-2)
  if pair_mask is not None:
    normed_error *= pair_mask

  # Masked mean over all (frame, position) pairs.
  mask = (jnp.expand_dims(frames_mask, axis=-1) *
          jnp.expand_dims(positions_mask, axis=-2))
  if pair_mask is not None:
    mask *= pair_mask
  normalization_factor = jnp.sum(mask, axis=(-1, -2))
  return (jnp.sum(normed_error, axis=(-2, -1)) /
          (epsilon + normalization_factor))
def get_chi_atom_indices():
  """Returns atom indices needed to compute chi angles for all residue types.

  Returns:
    A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types
    are in the order specified in residue_constants.restypes + unknown residue
    type at the end. For chi angles which are not defined on the residue, the
    positions indices are by default set to 0.
  """
  indices_per_restype = []
  for one_letter in residue_constants.restypes:
    three_letter = residue_constants.restype_1to3[one_letter]
    chi_angle_defs = residue_constants.chi_angles_atoms[three_letter]
    # One row of 4 atom indices per defined chi angle.
    rows = [[residue_constants.atom_order[atom] for atom in chi]
            for chi in chi_angle_defs]
    # Pad with dummy rows for residues with fewer than 4 chi angles.
    rows.extend([[0, 0, 0, 0]] * (4 - len(rows)))
    indices_per_restype.append(rows)
  # Dummy entry for the UNKNOWN residue type.
  indices_per_restype.append([[0, 0, 0, 0]] * 4)
  return jnp.asarray(indices_per_restype)
def compute_chi_angles(positions: geometry.Vec3Array,
                       mask: jnp.ndarray,
                       aatype: jnp.ndarray):
  """Computes the chi angles given all atom positions and the amino acid type.

  Args:
    positions: A Vec3Array of shape
      [num_res, residue_constants.atom_type_num], with positions of
      atoms needed to calculate chi angles. Supports up to 1 batch dimension.
    mask: An optional tensor of shape
      [num_res, residue_constants.atom_type_num] that masks which atom
      positions are set for each residue. If given, then the chi mask will be
      set to 1 for a chi angle only if the amino acid has that chi angle and all
      the chi atoms needed to calculate that chi angle are set. If not given
      (set to None), the chi mask will be set to 1 for a chi angle if the amino
      acid has that chi angle and whether the actual atoms needed to calculate
      it were set will be ignored.
    aatype: A tensor of shape [num_res] with amino acid type integer
      code (0 to 21). Supports up to 1 batch dimension.

  Returns:
    A tuple of tensors (chi_angles, mask), where both have shape
    [num_res, 4]. The mask masks out unused chi angles for amino acid
    types that have less than 4 chi angles. If atom_positions_mask is set, the
    chi mask will also mask out uncomputable chi angles.
  """
  # Don't assert on the num_res and batch dimensions as they might be unknown.
  assert positions.shape[-1] == residue_constants.atom_type_num
  assert mask.shape[-1] == residue_constants.atom_type_num

  # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
  chi_atom_indices = get_chi_atom_indices()
  # Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4].
  atom_indices = utils.batched_gather(
      params=chi_atom_indices, indices=aatype, axis=0)
  # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
  chi_angle_atoms = jax.tree_map(
      lambda x: utils.batched_gather(  # pylint: disable=g-long-lambda
          params=x, indices=atom_indices, axis=-1, batch_dims=1), positions)
  # Each chi angle is the dihedral defined by its 4 gathered atoms.
  a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]

  chi_angles = geometry.dihedral_angle(a, b, c, d)

  # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
  chi_angles_mask = list(residue_constants.chi_angles_mask)
  chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
  chi_angles_mask = jnp.asarray(chi_angles_mask)
  # Compute the chi angle mask. Shape [num_res, chis=4].
  chi_mask = utils.batched_gather(
      params=chi_angles_mask, indices=aatype, axis=0)

  # The chi_mask is set to 1 only when all necessary chi angle atoms were set.
  # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
  chi_angle_atoms_mask = utils.batched_gather(
      params=mask, indices=atom_indices, axis=-1, batch_dims=1)
  # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
  chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1])
  chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32)

  return chi_angles, chi_mask
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array,
    b_xyz: geometry.Vec3Array,
    c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array:
  """Returns rotation and translation matrices to convert from reference.

  Note that this method does not take care of symmetries. If you provide the
  coordinates in the non-standard way, the A atom will end up in the negative
  y-axis rather than in the positive y-axis. You need to take care of such
  cases in your code.

  Args:
    a_xyz: A Vec3Array.
    b_xyz: A Vec3Array.
    c_xyz: A Vec3Array.

  Returns:
    A Rigid3Array which, when applied to coordinates in a canonicalized
    reference frame, will give coordinates approximately equal
    the original coordinates (in the global frame).
  """
  # The frame is anchored at B; its axes are derived from the B->C and
  # B->A directions.
  bc_direction = c_xyz - b_xyz
  ba_direction = a_xyz - b_xyz
  frame_rotation = geometry.Rot3Array.from_two_vectors(bc_direction,
                                                       ba_direction)
  return geometry.Rigid3Array(frame_rotation, b_xyz)
alphafold/model/all_atom_test.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for all_atom."""
from
absl.testing
import
absltest
from
absl.testing
import
parameterized
from
alphafold.model
import
all_atom
from
alphafold.model
import
r3
import
numpy
as
np
# Clamp threshold passed to frame_aligned_point_error in the tests below;
# the same value is also used there as the length scale.
L1_CLAMP_DISTANCE = 10
def get_identity_rigid(shape):
  """Returns identity rigid transform."""
  one = np.ones(shape)
  zero = np.zeros(shape)
  # Identity rotation matrix, row-major: diagonal ones.
  identity_rot = r3.Rots(one, zero, zero,
                         zero, one, zero,
                         zero, zero, one)
  no_shift = r3.Vecs(zero, zero, zero)
  return r3.Rigids(identity_rot, no_shift)
def get_global_rigid_transform(rot_angle, translation, bcast_dims):
  """Returns rigid transform that globally rotates/translates by same amount."""
  angle_arr = np.asarray(rot_angle)
  shift_arr = np.asarray(translation)
  if bcast_dims:
    # Prepend broadcast dimensions so the transform applies elementwise.
    for _ in range(bcast_dims):
      angle_arr = np.expand_dims(angle_arr, 0)
      shift_arr = np.expand_dims(shift_arr, 0)
  theta = np.deg2rad(angle_arr)
  sin_t = np.sin(theta)
  cos_t = np.cos(theta)
  one = np.ones_like(sin_t)
  zero = np.zeros_like(sin_t)
  # Rotation by theta about the x-axis.
  rotation = r3.Rots(one, zero, zero,
                     zero, cos_t, -sin_t,
                     zero, sin_t, cos_t)
  offset = r3.Vecs(shift_arr[..., 0], shift_arr[..., 1], shift_arr[..., 2])
  return r3.Rigids(rotation, offset)
class AllAtomTest(parameterized.TestCase, absltest.TestCase):
  """Tests for all_atom.frame_aligned_point_error."""

  @parameterized.named_parameters(
      ('identity', 0, [0, 0, 0]),
      ('rot_90', 90, [0, 0, 0]),
      ('trans_10', 0, [0, 0, 10]),
      ('rot_174_trans_1', 174, [1, 1, 1]))
  def test_frame_aligned_point_error_perfect_on_global_transform(
      self, rot_angle, translation):
    """Tests global transform between target and preds gives perfect score."""

    # pylint: disable=bad-whitespace
    target_positions = np.array(
        [[21.182, 23.095, 19.731],
         [22.055, 20.919, 17.294],
         [24.599, 20.005, 15.041],
         [25.567, 18.214, 12.166],
         [28.063, 17.082, 10.043],
         [28.779, 15.569, 6.985],
         [30.581, 13.815, 4.612],
         [29.258, 12.193, 2.296]])
    # pylint: enable=bad-whitespace
    global_rigid_transform = get_global_rigid_transform(
        rot_angle, translation, 1)

    # Predictions are the targets moved rigidly by the same global transform,
    # so FAPE must be exactly zero.
    target_positions = r3.vecs_from_tensor(target_positions)
    pred_positions = r3.rigids_mul_vecs(
        global_rigid_transform, target_positions)
    positions_mask = np.ones(target_positions.x.shape[0])

    target_frames = get_identity_rigid(10)
    pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
    frames_mask = np.ones(10)

    fape = all_atom.frame_aligned_point_error(
        pred_frames, target_frames, frames_mask, pred_positions,
        target_positions, positions_mask, L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE, epsilon=0)
    self.assertAlmostEqual(fape, 0.)

  @parameterized.named_parameters(
      ('identity',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       0.),
      ('shift_2.5',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
       0.25),
      ('shift_5',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
       0.5),
      ('shift_10',
       [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
       [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
       1.))
  def test_frame_aligned_point_error_matches_expected(
      self, target_positions, pred_positions, expected_alddt):
    """Tests score matches expected."""
    # Identity frames: the error reduces to clamped, length-scaled distances
    # between corresponding points.
    target_frames = get_identity_rigid(2)
    pred_frames = target_frames
    frames_mask = np.ones(2)

    target_positions = r3.vecs_from_tensor(np.array(target_positions))
    pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
    positions_mask = np.ones(target_positions.x.shape[0])

    alddt = all_atom.frame_aligned_point_error(
        pred_frames, target_frames, frames_mask, pred_positions,
        target_positions, positions_mask, L1_CLAMP_DISTANCE,
        L1_CLAMP_DISTANCE, epsilon=0)
    self.assertAlmostEqual(alddt, expected_alddt)
# Run the tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
alphafold/model/common_modules.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of common Haiku modules for use in protein folding."""
import
numbers
from
typing
import
Union
,
Sequence
import
haiku
as
hk
import
jax.numpy
as
jnp
import
numpy
as
np
# Standard deviation of a unit normal truncated to [-2, 2]; used to rescale
# TruncatedNormal initializers so they have the requested post-truncation std.
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
                                            dtype=np.float32)
def get_initializer_scale(initializer_name, input_shape):
  """Get Initializer for weights and scale to multiply activations by."""
  if initializer_name == 'zeros':
    return hk.initializers.Constant(0.0)

  # Fan-in variance scaling: divide by every input channel dimension.
  variance = 1.
  for dim in input_shape:
    variance /= dim
  if initializer_name == 'relu':
    # He-style doubling for ReLU activations.
    variance *= 2
  # Truncated normal at +/- 2 sigma; rescale so the post-truncation std
  # matches the requested value.
  stddev = np.sqrt(variance) / TRUNCATED_NORMAL_STDDEV_FACTOR
  return hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
class Linear(hk.Module):
  """Protein folding specific Linear module.

  This differs from the standard Haiku Linear in a few ways:
    * It supports inputs and outputs of arbitrary rank
    * Initializers are specified by strings
  """

  def __init__(self,
               num_output: Union[int, Sequence[int]],
               initializer: str = 'linear',
               num_input_dims: int = 1,
               use_bias: bool = True,
               bias_init: float = 0.,
               precision=None,
               name: str = 'linear'):
    """Constructs Linear Module.

    Args:
      num_output: Number of output channels. Can be tuple when outputting
          multiple dimensions.
      initializer: What initializer to use, should be one of {'linear', 'relu',
        'zeros'}
      num_input_dims: Number of dimensions from the end to project.
      use_bias: Whether to include trainable bias
      bias_init: Value used to initialize bias.
      precision: What precision to use for matrix multiplication, defaults
        to None.
      name: Name of module, used for name scopes.
    """
    super().__init__(name=name)
    # Normalize num_output to a tuple so output_shape is always a shape tuple.
    if isinstance(num_output, numbers.Integral):
      self.output_shape = (num_output,)
    else:
      self.output_shape = tuple(num_output)
    self.initializer = initializer
    self.use_bias = use_bias
    self.bias_init = bias_init
    self.num_input_dims = num_input_dims
    self.num_output_dims = len(self.output_shape)
    self.precision = precision

  def __call__(self, inputs):
    """Connects Module.

    Args:
      inputs: Tensor with at least num_input_dims dimensions.

    Returns:
      output of shape [...] + num_output.
    """

    num_input_dims = self.num_input_dims

    # The last num_input_dims axes of the input are projected; everything
    # before them is treated as batch dimensions.
    if self.num_input_dims > 0:
      in_shape = inputs.shape[-self.num_input_dims:]
    else:
      in_shape = ()

    weight_init = get_initializer_scale(self.initializer, in_shape)

    # Build an einsum equation mapping input letters to output letters,
    # e.g. '...a, ah->...h' for the default 1-in/1-out case.
    in_letters = 'abcde'[:self.num_input_dims]
    out_letters = 'hijkl'[:self.num_output_dims]
    weight_shape = in_shape + self.output_shape
    weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
                               weight_init)
    equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'

    output = jnp.einsum(equation, inputs, weights, precision=self.precision)

    if self.use_bias:
      bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
                              hk.initializers.Constant(self.bias_init))
      output += bias

    return output
alphafold/model/config.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model config."""
import
copy
from
alphafold.model.tf
import
shape_placeholders
import
ml_collections
# Convenience aliases for the symbolic shape placeholder names used in the
# feature shape specifications below.
NUM_RES = shape_placeholders.NUM_RES
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
def model_config(name: str) -> ml_collections.ConfigDict:
  """Get the ConfigDict of a CASP14 model.

  Multimer model names all share one config; monomer names select a diff
  applied on top of the base CONFIG.
  """
  if 'multimer' in name:
    return CONFIG_MULTIMER
  if name not in CONFIG_DIFFS:
    raise ValueError(f'Invalid model name {name}.')
  overrides = CONFIG_DIFFS[name]
  model_cfg = copy.deepcopy(CONFIG)
  model_cfg.update_from_flattened_dict(overrides)
  return model_cfg
# Named groups of model configurations; each preset lists the model names to
# run for that mode.
MODEL_PRESETS = {
    'monomer': (
        'model_1',
        'model_2',
        'model_3',
        'model_4',
        'model_5',
    ),
    'monomer_ptm': (
        'model_1_ptm',
        'model_2_ptm',
        'model_3_ptm',
        'model_4_ptm',
        'model_5_ptm',
    ),
    'multimer': (
        'model_1_multimer_v2',
        'model_2_multimer_v2',
        'model_3_multimer_v2',
        'model_4_multimer_v2',
        'model_5_multimer_v2',
    ),
}
# The CASP14 preset uses the same models as the plain monomer preset.
MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer']
# Per-model overrides applied on top of the base CONFIG (flattened-key form).
CONFIG_DIFFS = {
    'model_1': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1
        'data.common.max_extra_msa': 5120,
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True
    },
    'model_2': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.1.2
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True
    },
    'model_3': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.1
        'data.common.max_extra_msa': 5120,
    },
    'model_4': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.2
        'data.common.max_extra_msa': 5120,
    },
    'model_5': {
        # Jumper et al. (2021) Suppl. Table 5, Model 1.2.3
    },

    # The following models are fine-tuned from the corresponding models above
    # with an additional predicted_aligned_error head that can produce
    # predicted TM-score (pTM) and predicted aligned errors.
    'model_1_ptm': {
        'data.common.max_extra_msa': 5120,
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_2_ptm': {
        'data.common.reduce_msa_clusters_by_max_templates': True,
        'data.common.use_templates': True,
        'model.embeddings_and_evoformer.template.embed_torsion_angles': True,
        'model.embeddings_and_evoformer.template.enabled': True,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_3_ptm': {
        'data.common.max_extra_msa': 5120,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_4_ptm': {
        'data.common.max_extra_msa': 5120,
        'model.heads.predicted_aligned_error.weight': 0.1
    },
    'model_5_ptm': {
        'model.heads.predicted_aligned_error.weight': 0.1
    }
}
# Base (monomer) model configuration. Shape placeholders (NUM_RES etc.) mark
# which feature dimensions the input pipeline pads/crops; None means a fixed
# trailing dimension determined by the feature itself.
CONFIG = ml_collections.ConfigDict({
    'data': {
        'common': {
            'masked_msa': {
                'profile_prob': 0.1,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'max_extra_msa': 1024,
            'msa_cluster_features': True,
            'num_recycle': 3,
            'reduce_msa_clusters_by_max_templates': False,
            'resample_msa_in_recycling': True,
            'template_features': [
                'template_all_atom_positions', 'template_sum_probs',
                'template_aatype', 'template_all_atom_masks',
                'template_domain_names'
            ],
            'unsupervised_features': [
                'aatype', 'residue_index', 'sequence', 'msa', 'domain_name',
                'num_alignments', 'seq_length', 'between_segment_residues',
                'deletion_matrix'
            ],
            'use_templates': False,
        },
        'eval': {
            # Per-feature shape schema used to pad/crop inputs to fixed size.
            'feat': {
                'aatype': [NUM_RES],
                'all_atom_mask': [NUM_RES, None],
                'all_atom_positions': [NUM_RES, None, None],
                'alt_chi_angles': [NUM_RES, None],
                'atom14_alt_gt_exists': [NUM_RES, None],
                'atom14_alt_gt_positions': [NUM_RES, None, None],
                'atom14_atom_exists': [NUM_RES, None],
                'atom14_atom_is_ambiguous': [NUM_RES, None],
                'atom14_gt_exists': [NUM_RES, None],
                'atom14_gt_positions': [NUM_RES, None, None],
                'atom37_atom_exists': [NUM_RES, None],
                'backbone_affine_mask': [NUM_RES],
                'backbone_affine_tensor': [NUM_RES, None],
                'bert_mask': [NUM_MSA_SEQ, NUM_RES],
                'chi_angles': [NUM_RES, None],
                'chi_mask': [NUM_RES, None],
                'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
                'extra_msa_row_mask': [NUM_EXTRA_SEQ],
                'is_distillation': [],
                'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
                'msa_mask': [NUM_MSA_SEQ, NUM_RES],
                'msa_row_mask': [NUM_MSA_SEQ],
                'pseudo_beta': [NUM_RES, None],
                'pseudo_beta_mask': [NUM_RES],
                'random_crop_to_size_seed': [None],
                'residue_index': [NUM_RES],
                'residx_atom14_to_atom37': [NUM_RES, None],
                'residx_atom37_to_atom14': [NUM_RES, None],
                'resolution': [],
                'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
                'rigidgroups_group_exists': [NUM_RES, None],
                'rigidgroups_group_is_ambiguous': [NUM_RES, None],
                'rigidgroups_gt_exists': [NUM_RES, None],
                'rigidgroups_gt_frames': [NUM_RES, None, None],
                'seq_length': [],
                'seq_mask': [NUM_RES],
                'target_feat': [NUM_RES, None],
                'template_aatype': [NUM_TEMPLATES, NUM_RES],
                'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
                'template_all_atom_positions': [
                    NUM_TEMPLATES, NUM_RES, None, None],
                'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
                'template_backbone_affine_tensor': [
                    NUM_TEMPLATES, NUM_RES, None],
                'template_mask': [NUM_TEMPLATES],
                'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
                'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
                'template_sum_probs': [NUM_TEMPLATES, None],
                'true_msa': [NUM_MSA_SEQ, NUM_RES]
            },
            'fixed_size': True,
            'subsample_templates': False,  # We want top templates.
            'masked_msa_replace_fraction': 0.15,
            'max_msa_clusters': 512,
            'max_templates': 4,
            'num_ensemble': 1,
        },
    },
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'first': False,
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'max_relative_feature': 32,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'seq_channel': 384,
            'template': {
                'attention': {
                    'gating': False,
                    'key_dim': 64,
                    'num_head': 4,
                    'value_dim': 64
                },
                'dgram_features': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39
                },
                'embed_torsion_angles': False,
                'enabled': False,
                'template_pair_stack': {
                    'num_block': 2,
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'key_dim': 64,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True,
                        'value_dim': 64
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                },
                'max_templates': 4,
                'subbatch_size': 128,
                'use_template_unit_vector': False,
            }
        },
        'global_config': {
            'deterministic': False,
            'multimer_mode': False,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'predicted_aligned_error': {
                # `num_bins - 1` bins uniformly space the
                # [0, max_error_bin A] range.
                # The final bin covers [max_error_bin A, +infty]
                # 31A gives bins with 0.5A width.
                'max_error_bin': 31.,
                'num_bins': 64,
                'num_channels': 128,
                'filter_by_resolution': True,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'structure_module': {
                'num_layer': 8,
                'fape': {
                    'clamp_distance': 10.0,
                    'clamp_type': 'relu',
                    'loss_unit_distance': 10.0
                },
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'compute_in_graph_metrics': True,
                'dropout': 0.1,
                'num_channel': 384,
                'num_head': 12,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 10.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5,
                    'length_scale': 10.,
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'masked_msa': {
                'num_output': 23,
                'weight': 2.0
            },
        },
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    },
})
# Configuration shared by all multimer models. Unlike the monomer CONFIG,
# there is no 'data' section here: multimer feature processing happens inside
# the model itself (note num_msa / num_extra_msa / masked_msa live under
# embeddings_and_evoformer).
CONFIG_MULTIMER = ml_collections.ConfigDict({
    'model': {
        'embeddings_and_evoformer': {
            'evoformer_num_block': 48,
            'evoformer': {
                'msa_column_attention': {
                    'dropout_rate': 0.0,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'msa_row_attention_with_pair_bias': {
                    'dropout_rate': 0.15,
                    'gating': True,
                    'num_head': 8,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'msa_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'outer_product_mean': {
                    'chunk_size': 128,
                    'dropout_rate': 0.0,
                    'first': True,
                    'num_outer_channel': 32,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'pair_transition': {
                    'dropout_rate': 0.0,
                    'num_intermediate_factor': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_attention_ending_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_column',
                    'shared_dropout': True
                },
                'triangle_attention_starting_node': {
                    'dropout_rate': 0.25,
                    'gating': True,
                    'num_head': 4,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_incoming': {
                    'dropout_rate': 0.25,
                    'equation': 'kjc,kic->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                },
                'triangle_multiplication_outgoing': {
                    'dropout_rate': 0.25,
                    'equation': 'ikc,jkc->ijc',
                    'num_intermediate_channel': 128,
                    'orientation': 'per_row',
                    'shared_dropout': True
                }
            },
            'extra_msa_channel': 64,
            'extra_msa_stack_num_block': 4,
            'num_msa': 252,
            'num_extra_msa': 1152,
            'masked_msa': {
                'profile_prob': 0.1,
                'replace_fraction': 0.15,
                'same_prob': 0.1,
                'uniform_prob': 0.1
            },
            'use_chain_relative': True,
            'max_relative_chain': 2,
            'max_relative_idx': 32,
            'seq_channel': 384,
            'msa_channel': 256,
            'pair_channel': 128,
            'prev_pos': {
                'max_bin': 20.75,
                'min_bin': 3.25,
                'num_bins': 15
            },
            'recycle_features': True,
            'recycle_pos': True,
            'template': {
                'attention': {
                    'gating': False,
                    'num_head': 4
                },
                'dgram_features': {
                    'max_bin': 50.75,
                    'min_bin': 3.25,
                    'num_bins': 39
                },
                'enabled': True,
                'max_templates': 4,
                'num_channels': 64,
                'subbatch_size': 128,
                'template_pair_stack': {
                    'num_block': 2,
                    'pair_transition': {
                        'dropout_rate': 0.0,
                        'num_intermediate_factor': 2,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_attention_ending_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'num_head': 4,
                        'orientation': 'per_column',
                        'shared_dropout': True
                    },
                    'triangle_attention_starting_node': {
                        'dropout_rate': 0.25,
                        'gating': True,
                        'num_head': 4,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_incoming': {
                        'dropout_rate': 0.25,
                        'equation': 'kjc,kic->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    },
                    'triangle_multiplication_outgoing': {
                        'dropout_rate': 0.25,
                        'equation': 'ikc,jkc->ijc',
                        'num_intermediate_channel': 64,
                        'orientation': 'per_row',
                        'shared_dropout': True
                    }
                }
            },
        },
        'global_config': {
            'deterministic': False,
            'multimer_mode': True,
            'subbatch_size': 4,
            'use_remat': False,
            'zero_init': True
        },
        'heads': {
            'distogram': {
                'first_break': 2.3125,
                'last_break': 21.6875,
                'num_bins': 64,
                'weight': 0.3
            },
            'experimentally_resolved': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'weight': 0.01
            },
            'masked_msa': {
                'weight': 2.0
            },
            'predicted_aligned_error': {
                'filter_by_resolution': True,
                'max_error_bin': 31.0,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 64,
                'num_channels': 128,
                'weight': 0.1
            },
            'predicted_lddt': {
                'filter_by_resolution': True,
                'max_resolution': 3.0,
                'min_resolution': 0.1,
                'num_bins': 50,
                'num_channels': 128,
                'weight': 0.01
            },
            'structure_module': {
                'angle_norm_weight': 0.01,
                'chi_weight': 0.5,
                'clash_overlap_tolerance': 1.5,
                'dropout': 0.1,
                # Inter-chain FAPE is effectively unclamped
                # (atom_clamp_distance 1000.0) but uses a larger length scale.
                'interface_fape': {
                    'atom_clamp_distance': 1000.0,
                    'loss_unit_distance': 20.0
                },
                'intra_chain_fape': {
                    'atom_clamp_distance': 10.0,
                    'loss_unit_distance': 10.0
                },
                'num_channel': 384,
                'num_head': 12,
                'num_layer': 8,
                'num_layer_in_transition': 3,
                'num_point_qk': 4,
                'num_point_v': 8,
                'num_scalar_qk': 16,
                'num_scalar_v': 16,
                'position_scale': 20.0,
                'sidechain': {
                    'atom_clamp_distance': 10.0,
                    'loss_unit_distance': 10.0,
                    'num_channel': 128,
                    'num_residual_block': 2,
                    'weight_frac': 0.5
                },
                'structural_violation_loss_weight': 1.0,
                'violation_tolerance_factor': 12.0,
                'weight': 1.0
            }
        },
        'num_ensemble_eval': 1,
        'num_recycle': 3,
        'resample_msa_in_recycling': True
    }
})
alphafold/model/data.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience functions for reading data."""
import
io
import
os
from
alphafold.model
import
utils
import
haiku
as
hk
import
numpy
as
np
# Internal import (7716).
def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
  """Get the Haiku parameters from a model name.

  Args:
    model_name: Name of the model, e.g. 'model_1'; the file loaded is
      '<data_dir>/params/params_<model_name>.npz'.
    data_dir: Root directory holding the 'params' subdirectory.

  Returns:
    The model parameters as a nested Haiku params structure.
  """
  npz_path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
  # Read the whole file first and hand np.load an in-memory buffer.
  with open(npz_path, 'rb') as npz_file:
    raw_bytes = npz_file.read()
  flat_params = np.load(io.BytesIO(raw_bytes), allow_pickle=False)
  return utils.flat_params_to_haiku(flat_params)
alphafold/model/features.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code to generate processed features."""
import
copy
from
typing
import
List
,
Mapping
,
Tuple
from
alphafold.model.tf
import
input_pipeline
from
alphafold.model.tf
import
proteins_dataset
import
ml_collections
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
def make_data_config(
    config: ml_collections.ConfigDict,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
  """Makes a data config for the input pipeline.

  Args:
    config: Full model config; only its `data` section is used.
    num_res: Number of residues in the target, used as the crop size.

  Returns:
    A tuple of (deep-copied data config with eval.crop_size set to num_res,
    list of feature names to read — unsupervised features plus template
    features when templates are enabled).
  """
  data_cfg = copy.deepcopy(config.data)
  wanted_features = data_cfg.common.unsupervised_features
  if data_cfg.common.use_templates:
    wanted_features += data_cfg.common.template_features
  # crop_size is not part of the frozen schema, so set it while unlocked.
  with data_cfg.unlocked():
    data_cfg.eval.crop_size = num_res
  return data_cfg, wanted_features
def tf_example_to_features(tf_example: tf.train.Example,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Converts tf_example to numpy feature dictionary."""
  num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
  cfg, feature_names = make_data_config(config, num_res=num_res)

  # Legacy examples store the deletion matrix as int64; convert it in place to
  # the float 'deletion_matrix' feature the pipeline expects.
  if 'deletion_matrix_int' in set(tf_example.features.feature):
    deletion_matrix_int = (
        tf_example.features.feature['deletion_matrix_int'].int64_list.value)
    feat = tf.train.Feature(float_list=tf.train.FloatList(
        value=map(float, deletion_matrix_int)))
    tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
    del tf_example.features.feature['deletion_matrix_int']

  # Build the TF1 preprocessing graph on CPU with a fixed seed so feature
  # processing is reproducible for a given random_seed.
  tf_graph = tf.Graph()
  with tf_graph.as_default(), tf.device('/device:CPU:0'):
    tf.compat.v1.set_random_seed(random_seed)
    tensor_dict = proteins_dataset.create_tensor_dict(
        raw_data=tf_example.SerializeToString(),
        features=feature_names)
    processed_batch = input_pipeline.process_tensors_from_config(
        tensor_dict, cfg)

  tf_graph.finalize()

  with tf.Session(graph=tf_graph) as sess:
    features = sess.run(processed_batch)

  # Drop object-dtype entries (e.g. string features) so the result contains
  # only numeric numpy arrays.
  return {k: v for k, v in features.items() if v.dtype != 'O'}
def np_example_to_features(np_example: FeatureDict,
                           config: ml_collections.ConfigDict,
                           random_seed: int = 0) -> FeatureDict:
  """Preprocesses NumPy feature dict using TF pipeline."""
  # Shallow-copy so popping keys below does not mutate the caller's dict.
  np_example = dict(np_example)
  num_res = int(np_example['seq_length'][0])
  cfg, feature_names = make_data_config(config, num_res=num_res)

  # Legacy inputs store the deletion matrix as ints; rename and cast to the
  # float 'deletion_matrix' feature the pipeline expects.
  if 'deletion_matrix_int' in np_example:
    np_example['deletion_matrix'] = (
        np_example.pop('deletion_matrix_int').astype(np.float32))

  # Build the TF1 preprocessing graph on CPU with a fixed seed so feature
  # processing is reproducible for a given random_seed.
  tf_graph = tf.Graph()
  with tf_graph.as_default(), tf.device('/device:CPU:0'):
    tf.compat.v1.set_random_seed(random_seed)
    tensor_dict = proteins_dataset.np_to_tensor_dict(
        np_example=np_example, features=feature_names)

    processed_batch = input_pipeline.process_tensors_from_config(
        tensor_dict, cfg)

  tf_graph.finalize()

  with tf.Session(graph=tf_graph) as sess:
    features = sess.run(processed_batch)

  # Drop object-dtype entries (e.g. string features) so the result contains
  # only numeric numpy arrays.
  return {k: v for k, v in features.items() if v.dtype != 'O'}
alphafold/model/folding.py
deleted
100644 → 0
View file @
a1597f3f
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and utilities for the structure module."""
import
functools
from
typing
import
Dict
from
alphafold.common
import
residue_constants
from
alphafold.model
import
all_atom
from
alphafold.model
import
common_modules
from
alphafold.model
import
prng
from
alphafold.model
import
quat_affine
from
alphafold.model
import
r3
from
alphafold.model
import
utils
import
haiku
as
hk
import
jax
import
jax.numpy
as
jnp
import
ml_collections
import
numpy
as
np
def squared_difference(x, y):
  """Return the elementwise squared difference (x - y)**2 of two arrays."""
  delta = x - y
  return jnp.square(delta)
class InvariantPointAttention(hk.Module):
  """Invariant Point attention module.

  The high-level idea is that this attention module works over a set of points
  and associated orientations in 3D space (e.g. protein residues).

  Each residue outputs a set of queries and keys as points in their local
  reference frame. The attention is then defined as the euclidean distance
  between the queries and keys in the global frame.

  Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention"
  """

  def __init__(self,
               config,
               global_config,
               dist_epsilon=1e-8,
               name='invariant_point_attention'):
    """Initialize.

    Args:
      config: Structure Module Config
      global_config: Global Config of Model.
      dist_epsilon: Small value to avoid NaN in distance calculation.
      name: Haiku Module name.
    """
    super().__init__(name=name)

    self._dist_epsilon = dist_epsilon
    # zero_init controls whether the final output projection starts at zero.
    self._zero_initialize_last = global_config.zero_init

    self.config = config
    self.global_config = global_config

  def __call__(self, inputs_1d, inputs_2d, mask, affine):
    """Compute geometry-aware attention.

    Given a set of query residues (defined by affines and associated scalar
    features), this function computes geometry-aware attention between the
    query residues and target residues.

    The residues produce points in their local reference frame, which
    are converted into the global frame in order to compute attention via
    euclidean distance.

    Equivalently, the target residues produce points in their local frame to be
    used as attention values, which are converted into the query residues'
    local frames.

    Args:
      inputs_1d: (N, C) 1D input embedding that is the basis for the
        scalar queries.
      inputs_2d: (N, M, C') 2D input embedding, used for biases and values.
      mask: (N, 1) mask to indicate which elements of inputs_1d participate
        in the attention.
      affine: QuatAffine object describing the position and orientation of
        every element in inputs_1d.

    Returns:
      Transformation of the input embedding.
    """
    num_residues, _ = inputs_1d.shape

    # Improve readability by removing a large number of 'self's.
    num_head = self.config.num_head
    num_scalar_qk = self.config.num_scalar_qk
    num_point_qk = self.config.num_point_qk
    num_scalar_v = self.config.num_scalar_v
    num_point_v = self.config.num_point_v
    num_output = self.config.num_channel

    assert num_scalar_qk > 0
    assert num_point_qk > 0
    assert num_point_v > 0

    # Construct scalar queries of shape:
    # [num_query_residues, num_head, num_points]
    q_scalar = common_modules.Linear(
        num_head * num_scalar_qk, name='q_scalar')(
            inputs_1d)
    q_scalar = jnp.reshape(
        q_scalar, [num_residues, num_head, num_scalar_qk])

    # Construct scalar keys/values of shape:
    # [num_target_residues, num_head, num_points]
    kv_scalar = common_modules.Linear(
        num_head * (num_scalar_v + num_scalar_qk), name='kv_scalar')(
            inputs_1d)
    kv_scalar = jnp.reshape(kv_scalar,
                            [num_residues, num_head,
                             num_scalar_v + num_scalar_qk])
    k_scalar, v_scalar = jnp.split(kv_scalar, [num_scalar_qk], axis=-1)

    # Construct query points of shape:
    # [num_residues, num_head, num_point_qk]

    # First construct query points in local frame.
    q_point_local = common_modules.Linear(
        num_head * 3 * num_point_qk, name='q_point_local')(
            inputs_1d)
    # Split into a 3-list of (x, y, z) coordinate arrays.
    q_point_local = jnp.split(q_point_local, 3, axis=-1)
    # Project query points into global frame.
    q_point_global = affine.apply_to_point(q_point_local, extra_dims=1)
    # Reshape query point for later use.
    q_point = [
        jnp.reshape(x, [num_residues, num_head, num_point_qk])
        for x in q_point_global]

    # Construct key and value points.
    # Key points have shape [num_residues, num_head, num_point_qk]
    # Value points have shape [num_residues, num_head, num_point_v]

    # Construct key and value points in local frame.
    kv_point_local = common_modules.Linear(
        num_head * 3 * (num_point_qk + num_point_v), name='kv_point_local')(
            inputs_1d)
    kv_point_local = jnp.split(kv_point_local, 3, axis=-1)
    # Project key and value points into global frame.
    kv_point_global = affine.apply_to_point(kv_point_local, extra_dims=1)
    kv_point_global = [
        jnp.reshape(x, [num_residues, num_head, (num_point_qk + num_point_v)])
        for x in kv_point_global]
    # Split key and value points.
    k_point, v_point = list(
        zip(*[
            jnp.split(x, [num_point_qk,], axis=-1)
            for x in kv_point_global
        ]))

    # We assume that all queries and keys come iid from N(0, 1) distribution
    # and compute the variances of the attention logits.
    # Each scalar pair (q, k) contributes Var q*k = 1
    scalar_variance = max(num_scalar_qk, 1) * 1.
    # Each point pair (q, k) contributes Var [0.5 ||q||^2 - <q, k>] = 9 / 2
    point_variance = max(num_point_qk, 1) * 9. / 2

    # Allocate equal variance to scalar, point and attention 2d parts so that
    # the sum is 1.

    num_logit_terms = 3

    scalar_weights = np.sqrt(1.0 / (num_logit_terms * scalar_variance))
    point_weights = np.sqrt(1.0 / (num_logit_terms * point_variance))
    attention_2d_weights = np.sqrt(1.0 / (num_logit_terms))

    # Trainable per-head weights for points.
    trainable_point_weights = jax.nn.softplus(hk.get_parameter(
        'trainable_point_weights', shape=[num_head],
        # softplus^{-1} (1)
        init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))))
    point_weights *= jnp.expand_dims(trainable_point_weights, axis=1)

    # Move the head axis in front of the residue axis for the attention math.
    v_point = [jnp.swapaxes(x, -2, -3) for x in v_point]

    q_point = [jnp.swapaxes(x, -2, -3) for x in q_point]
    k_point = [jnp.swapaxes(x, -2, -3) for x in k_point]
    # Pairwise squared distances between query and key points, summed over
    # the three coordinates.
    dist2 = [
        squared_difference(qx[:, :, None, :], kx[:, None, :, :])
        for qx, kx in zip(q_point, k_point)
    ]
    dist2 = sum(dist2)
    attn_qk_point = -0.5 * jnp.sum(
        point_weights[:, None, None, :] * dist2, axis=-1)

    v = jnp.swapaxes(v_scalar, -2, -3)
    q = jnp.swapaxes(scalar_weights * q_scalar, -2, -3)
    k = jnp.swapaxes(k_scalar, -2, -3)
    attn_qk_scalar = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
    attn_logits = attn_qk_scalar + attn_qk_point

    attention_2d = common_modules.Linear(
        num_head, name='attention_2d')(
            inputs_2d)

    attention_2d = jnp.transpose(attention_2d, [2, 0, 1])
    attention_2d = attention_2d_weights * attention_2d
    attn_logits += attention_2d

    mask_2d = mask * jnp.swapaxes(mask, -1, -2)
    # Large negative bias drives masked logits to ~0 after softmax.
    attn_logits -= 1e5 * (1. - mask_2d)

    # [num_head, num_query_residues, num_target_residues]
    attn = jax.nn.softmax(attn_logits)

    # [num_head, num_query_residues, num_head * num_scalar_v]
    result_scalar = jnp.matmul(attn, v)

    # For point result, implement matmul manually so that it will be a float32
    # on TPU.  This is equivalent to
    # result_point_global = [jnp.einsum('bhqk,bhkc->bhqc', attn, vx)
    #                        for vx in v_point]
    # but on the TPU, doing the multiply and reduce_sum ensures the
    # computation happens in float32 instead of bfloat16.
    result_point_global = [jnp.sum(
        attn[:, :, :, None] * vx[:, None, :, :],
        axis=-2) for vx in v_point]

    # [num_query_residues, num_head, num_head * num_(scalar|point)_v]
    result_scalar = jnp.swapaxes(result_scalar, -2, -3)
    result_point_global = [
        jnp.swapaxes(x, -2, -3)
        for x in result_point_global]

    # Features used in the linear output projection. Should have the size
    # [num_query_residues, ?]
    output_features = []

    result_scalar = jnp.reshape(
        result_scalar, [num_residues, num_head * num_scalar_v])
    output_features.append(result_scalar)

    result_point_global = [
        jnp.reshape(r, [num_residues, num_head * num_point_v])
        for r in result_point_global]
    # Attended points expressed back in each residue's local frame.
    result_point_local = affine.invert_point(result_point_global, extra_dims=1)
    output_features.extend(result_point_local)

    # Also feed the distance of each attended point from its local origin
    # (epsilon keeps the sqrt gradient finite at zero).
    output_features.append(jnp.sqrt(self._dist_epsilon +
                                    jnp.square(result_point_local[0]) +
                                    jnp.square(result_point_local[1]) +
                                    jnp.square(result_point_local[2])))

    # Dimensions: h = heads, i and j = residues,
    # c = inputs_2d channels
    # Contraction happens over the second residue dimension, similarly to how
    # the usual attention is performed.
    result_attention_over_2d = jnp.einsum('hij, ijc->ihc', attn, inputs_2d)
    num_out = num_head * result_attention_over_2d.shape[-1]
    output_features.append(
        jnp.reshape(result_attention_over_2d,
                    [num_residues, num_out]))

    final_init = 'zeros' if self._zero_initialize_last else 'linear'

    final_act = jnp.concatenate(output_features, axis=-1)

    return common_modules.Linear(
        num_output,
        initializer=final_init,
        name='output_projection')(final_act)
class FoldIteration(hk.Module):
  """A single iteration of the main structure module loop.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" lines 6-21

  First, each residue attends to all residues using InvariantPointAttention.
  Then, we apply transition layers to update the hidden representations.
  Finally, we use the hidden representations to produce an update to the
  affine of each residue.
  """

  def __init__(self, config, global_config,
               name='fold_iteration'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self,
               activations,
               sequence_mask,
               update_affine,
               is_training,
               initial_act,
               safe_key=None,
               static_feat_2d=None,
               aatype=None):
    """Run one fold iteration.

    Args:
      activations: Dict with 'act' (per-residue activations) and 'affine'
        (rigid-body frames as a tensor).
      sequence_mask: Per-residue mask for the attention.
      update_affine: Whether to produce a backbone affine update.
      is_training: Enables dropout when True.
      initial_act: Activations from before the first iteration, also fed to
        the sidechain network.
      safe_key: Optional prng.SafeKey; drawn from hk.next_rng_key() if None.
      static_feat_2d: Pair activations used as the 2D attention input.
      aatype: Amino-acid types passed to the sidechain network.

    Returns:
      A (new_activations, outputs) pair; outputs holds the (non-detached)
      affine tensor and sidechain results for this iteration.
    """
    c = self.config

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    def safe_dropout_fn(tensor, safe_key):
      # Dropout that is disabled when the model runs deterministically.
      return prng.safe_dropout(
          tensor=tensor,
          safe_key=safe_key,
          rate=c.dropout,
          is_deterministic=self.global_config.deterministic,
          is_training=is_training)

    affine = quat_affine.QuatAffine.from_tensor(activations['affine'])

    act = activations['act']
    attention_module = InvariantPointAttention(self.config, self.global_config)
    # Attention
    attn = attention_module(
        inputs_1d=act,
        inputs_2d=static_feat_2d,
        mask=sequence_mask,
        affine=affine)
    act += attn
    safe_key, *sub_keys = safe_key.split(3)
    sub_keys = iter(sub_keys)
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='attention_layer_norm')(
            act)

    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Transition (residual MLP; all but the last layer use relu).
    input_act = act
    for i in range(c.num_layer_in_transition):
      init = 'relu' if i < c.num_layer_in_transition - 1 else final_init
      act = common_modules.Linear(
          c.num_channel,
          initializer=init,
          name='transition')(
              act)
      if i < c.num_layer_in_transition - 1:
        act = jax.nn.relu(act)
    act += input_act
    act = safe_dropout_fn(act, next(sub_keys))
    act = hk.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='transition_layer_norm')(act)

    if update_affine:
      # This block corresponds to
      # Jumper et al. (2021) Alg. 23 "Backbone update"
      affine_update_size = 6

      # Affine update
      affine_update = common_modules.Linear(
          affine_update_size,
          initializer=final_init,
          name='affine_update')(
              act)

      affine = affine.pre_compose(affine_update)

    sc = MultiRigidSidechain(c.sidechain, self.global_config)(
        affine.scale_translation(c.position_scale), [act, initial_act], aatype)

    outputs = {'affine': affine.to_tensor(), 'sc': sc}

    # Stop rotation gradients between iterations for stability (the outputs
    # above keep full gradients).
    affine = affine.apply_rotation_tensor_fn(jax.lax.stop_gradient)

    new_activations = {
        'act': act,
        'affine': affine.to_tensor()
    }
    return new_activations, outputs
def generate_affines(representations, batch, config, global_config,
                     is_training, safe_key):
  """Generate predicted affines for a single chain.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"

  This is the main part of the structure module - it iteratively applies
  folding to produce a set of predicted residue positions.

  Args:
    representations: Representations dictionary.
    batch: Batch dictionary.
    config: Config for the structure module.
    global_config: Global config.
    is_training: Whether the model is being trained.
    safe_key: A prng.SafeKey object that wraps a PRNG key.

  Returns:
    A dictionary containing residue affines and sidechain positions.
  """
  c = config
  sequence_mask = batch['seq_mask'][:, None]

  act = hk.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='single_layer_norm')(
          representations['single'])

  initial_act = act
  act = common_modules.Linear(
      c.num_channel, name='initial_projection')(
          act)

  # Start from identity frames ("black hole" initialization).
  affine = generate_new_affine(sequence_mask)

  fold_iteration = FoldIteration(
      c, global_config, name='fold_iteration')

  assert len(batch['seq_mask'].shape) == 1

  activations = {'act': act,
                 'affine': affine.to_tensor(),
                 }

  act_2d = hk.LayerNorm(
      axis=[-1],
      create_scale=True,
      create_offset=True,
      name='pair_layer_norm')(
          representations['pair'])

  outputs = []
  # One RNG key per layer; the same fold_iteration module (shared weights)
  # is applied c.num_layer times.
  safe_keys = safe_key.split(c.num_layer)
  for sub_key in safe_keys:
    activations, output = fold_iteration(
        activations,
        initial_act=initial_act,
        static_feat_2d=act_2d,
        safe_key=sub_key,
        sequence_mask=sequence_mask,
        update_affine=True,
        is_training=is_training,
        aatype=batch['aatype'])
    outputs.append(output)

  # Stack per-iteration outputs along a new leading (layer) axis.
  output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
  # Include the activations in the output dict for use by the LDDT-Head.
  output['act'] = activations['act']

  return output
class StructureModule(hk.Module):
  """StructureModule as a network head.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule"
  """

  def __init__(self, config, global_config, compute_loss=True,
               name='structure_module'):
    """Initializes the head.

    Args:
      config: Structure module configuration.
      global_config: Model-wide configuration.
      compute_loss: If True, __call__ returns the full output dict needed for
        loss computation; otherwise only inference outputs are returned.
      name: Haiku module name.
    """
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config
    self.compute_loss = compute_loss

  def __call__(self, representations, batch, is_training,
               safe_key=None):
    """Runs the structure module and assembles its outputs.

    Args:
      representations: Dict with 'single' and 'pair' activations (consumed by
        generate_affines).
      batch: Batch dict; 'atom14_atom_exists' and 'atom37_atom_exists' are
        used to mask the predicted positions.
      is_training: Whether the model is being trained.
      safe_key: Optional prng.SafeKey; drawn from hk.next_rng_key() if None.

    Returns:
      Dict of outputs; when compute_loss is False only
      'final_atom_positions', 'final_atom_mask' and 'representations' are
      returned.
    """
    c = self.config
    ret = {}

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    output = generate_affines(
        representations=representations,
        batch=batch,
        config=self.config,
        global_config=self.global_config,
        is_training=is_training,
        safe_key=safe_key)

    ret['representations'] = {'structure_module': output['act']}

    # Scale only the translation components (last 3) back by position_scale;
    # the 4 quaternion components are left untouched.
    ret['traj'] = output['affine'] * jnp.array([1.] * 4 +
                                               [c.position_scale] * 3)

    ret['sidechains'] = output['sc']

    # Positions from the final fold iteration only.
    atom14_pred_positions = r3.vecs_to_tensor(output['sc']['atom_pos'])[-1]
    ret['final_atom14_positions'] = atom14_pred_positions  # (N, 14, 3)
    ret['final_atom14_mask'] = batch['atom14_atom_exists']  # (N, 14)

    # Convert from the compact atom14 layout to the full atom37 layout and
    # zero out atoms that do not exist for the residue type.
    atom37_pred_positions = all_atom.atom14_to_atom37(atom14_pred_positions,
                                                      batch)
    atom37_pred_positions *= batch['atom37_atom_exists'][:, :, None]
    ret['final_atom_positions'] = atom37_pred_positions  # (N, 37, 3)

    ret['final_atom_mask'] = batch['atom37_atom_exists']  # (N, 37)
    ret['final_affines'] = ret['traj'][-1]

    if self.compute_loss:
      return ret
    else:
      no_loss_features = ['final_atom_positions', 'final_atom_mask',
                          'representations']
      no_loss_ret = {k: ret[k] for k in no_loss_features}
      return no_loss_ret

  def loss(self, value, batch):
    """Computes the structure-module training loss.

    NOTE: mutates `value` in place (adds the renamed ground truth entries and
    'violations' when they are computed).

    Args:
      value: Structure-module output dict from __call__.
      batch: Batch dict with ground-truth features.

    Returns:
      Dict with 'loss' plus individual loss terms and 'metrics'.
    """
    ret = {'loss': 0.}

    ret['metrics'] = {}
    # If requested, compute in-graph metrics.
    if self.config.compute_in_graph_metrics:
      atom14_pred_positions = value['final_atom14_positions']

      # Compute renaming and violations.
      value.update(compute_renamed_ground_truth(batch, atom14_pred_positions))
      value['violations'] = find_structural_violations(
          batch, atom14_pred_positions, self.config)

      # Several violation metrics:
      violation_metrics = compute_violation_metrics(
          batch=batch,
          atom14_pred_positions=atom14_pred_positions,
          violations=value['violations'])
      ret['metrics'].update(violation_metrics)

    backbone_loss(ret, batch, value, self.config)

    # The renamed ground truth may already be present from the metrics branch
    # above; compute it only if missing.
    if 'renamed_atom14_gt_positions' not in value:
      value.update(compute_renamed_ground_truth(
          batch, value['final_atom14_positions']))
    sc_loss = sidechain_loss(batch, value, self.config)

    # Blend backbone and sidechain FAPE by weight_frac.
    ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] +
                   self.config.sidechain.weight_frac * sc_loss['loss'])
    ret['sidechain_fape'] = sc_loss['fape']

    supervised_chi_loss(ret, batch, value, self.config)

    if self.config.structural_violation_loss_weight:
      if 'violations' not in value:
        value['violations'] = find_structural_violations(
            batch, value['final_atom14_positions'], self.config)
      structural_violation_loss(ret, batch, value, self.config)

    return ret
def compute_renamed_ground_truth(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,
) -> Dict[str, jnp.ndarray]:
  """Find optimal renaming of ground truth based on the predicted positions.

  Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms"

  This renamed ground truth is then used for all losses,
  such that each loss moves the atoms in the same direction.
  Shape (N).

  Args:
    batch: Dictionary containing:
      * atom14_gt_positions: Ground truth positions.
      * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
      * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
          renaming swaps.
      * atom14_gt_exists: Mask for which atoms exist in ground truth.
      * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
          after renaming.
      * atom14_atom_exists: Mask for whether each atom is part of the given
          amino acid type.
    atom14_pred_positions: Array of atom positions in global frame with shape
      (N, 14, 3).

  Returns:
    Dictionary containing:
      alt_naming_is_better: Array with 1.0 where alternative swap is better.
      renamed_atom14_gt_positions: Array of optimal ground truth positions
        after renaming swaps are performed.
      renamed_atom14_gt_exists: Mask after renaming swap is performed.
  """
  # Per-residue indicator (float 0/1) that the alternative naming fits the
  # prediction better than the default one.
  use_alt = all_atom.find_optimal_renaming(
      atom14_gt_positions=batch['atom14_gt_positions'],
      atom14_alt_gt_positions=batch['atom14_alt_gt_positions'],
      atom14_atom_is_ambiguous=batch['atom14_atom_is_ambiguous'],
      atom14_gt_exists=batch['atom14_gt_exists'],
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'])

  def _select(default, alternative, selector):
    # Linear blend; selector is 0.0 (keep default) or 1.0 (take alternative).
    return (1. - selector) * default + selector * alternative

  renamed_positions = _select(batch['atom14_gt_positions'],
                              batch['atom14_alt_gt_positions'],
                              use_alt[:, None, None])
  renamed_mask = _select(batch['atom14_gt_exists'],
                         batch['atom14_alt_gt_exists'],
                         use_alt[:, None])

  return {
      'alt_naming_is_better': use_alt,  # (N)
      'renamed_atom14_gt_positions': renamed_positions,  # (N, 14, 3)
      'renamed_atom14_gt_exists': renamed_mask,  # (N, 14)
  }
def backbone_loss(ret, batch, value, config):
  """Backbone FAPE Loss.

  Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" line 17

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'backbone_affine_tensor',
      'backbone_affine_mask'.
    value: Dictionary containing structure module output, needs to contain
      'traj', a trajectory of rigids.
    config: Configuration of loss, should contain 'fape.clamp_distance' and
      'fape.loss_unit_distance'.
  """
  pred_affines = quat_affine.QuatAffine.from_tensor(value['traj'])
  pred_rigids = r3.rigids_from_quataffine(pred_affines)

  target_affine = quat_affine.QuatAffine.from_tensor(
      batch['backbone_affine_tensor'])
  target_rigids = r3.rigids_from_quataffine(target_affine)
  frame_mask = batch['backbone_affine_mask']

  def _trajectory_fape(clamp_distance):
    # FAPE vmapped over the leading (trajectory) axis of the predicted frames
    # and positions; the ground truth and masks are shared across iterations.
    fape_fn = functools.partial(
        all_atom.frame_aligned_point_error,
        l1_clamp_distance=clamp_distance,
        length_scale=config.fape.loss_unit_distance)
    return jax.vmap(fape_fn, (0, None, None, 0, None, None))

  fape_loss = _trajectory_fape(config.fape.clamp_distance)(
      pred_rigids, target_rigids, frame_mask,
      pred_rigids.trans, target_rigids.trans, frame_mask)

  if 'use_clamped_fape' in batch:
    # Jumper et al. (2021) Suppl. Sec. 1.11.5 "Loss clamping details":
    # per-example interpolation between the clamped and unclamped variants.
    use_clamped_fape = jnp.asarray(batch['use_clamped_fape'], jnp.float32)
    fape_loss_unclamped = _trajectory_fape(None)(
        pred_rigids, target_rigids, frame_mask,
        pred_rigids.trans, target_rigids.trans, frame_mask)
    fape_loss = (fape_loss * use_clamped_fape +
                 fape_loss_unclamped * (1 - use_clamped_fape))

  # Report the loss of the final iteration, but train on the mean over all
  # fold iterations.
  ret['fape'] = fape_loss[-1]
  ret['loss'] += jnp.mean(fape_loss)
def sidechain_loss(batch, value, config):
  """All Atom FAPE Loss using renamed rigids."""
  # Rename Frames
  # Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" line 7
  use_alt = value['alt_naming_is_better'][:, None, None]
  renamed_gt_frames = (
      (1. - use_alt) * batch['rigidgroups_gt_frames'] +
      use_alt * batch['rigidgroups_alt_gt_frames'])

  # Flatten residues x rigid-groups into a single axis for FAPE.
  flat_gt_frames = r3.rigids_from_tensor_flat12(
      jnp.reshape(renamed_gt_frames, [-1, 12]))
  flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists'], [-1])

  flat_gt_positions = r3.vecs_from_tensor(
      jnp.reshape(value['renamed_atom14_gt_positions'], [-1, 3]))
  flat_positions_mask = jnp.reshape(value['renamed_atom14_gt_exists'], [-1])

  # Compute frame_aligned_point_error score for the final layer.
  pred_frames = value['sidechains']['frames']
  pred_positions = value['sidechains']['atom_pos']

  def _final_iteration_flat(x):
    # Keep only the last fold iteration and flatten the remaining axes.
    return jnp.reshape(x[-1], [-1])

  flat_pred_frames = jax.tree_map(_final_iteration_flat, pred_frames)
  flat_pred_positions = jax.tree_map(_final_iteration_flat, pred_positions)

  # FAPE Loss on sidechains
  fape = all_atom.frame_aligned_point_error(
      pred_frames=flat_pred_frames,
      target_frames=flat_gt_frames,
      frames_mask=flat_frames_mask,
      pred_positions=flat_pred_positions,
      target_positions=flat_gt_positions,
      positions_mask=flat_positions_mask,
      l1_clamp_distance=config.sidechain.atom_clamp_distance,
      length_scale=config.sidechain.length_scale)

  return {'fape': fape, 'loss': fape}
def structural_violation_loss(ret, batch, value, config):
  """Computes loss for structural violations."""
  assert config.sidechain.weight_frac

  # Put all violation losses together to one large loss.
  between = value['violations']['between_residues']
  within = value['violations']['within_residues']
  num_atoms = jnp.sum(batch['atom14_atom_exists']).astype(jnp.float32)

  # Per-atom clash/bound penalties, averaged over the number of real atoms.
  clash_term = jnp.sum(
      between['clashes_per_atom_loss_sum'] +
      within['per_atom_loss_sum']) / (1e-6 + num_atoms)

  ret['loss'] += config.structural_violation_loss_weight * (
      between['bonds_c_n_loss_mean'] +
      between['angles_ca_c_n_loss_mean'] +
      between['angles_c_n_ca_loss_mean'] +
      clash_term)
def find_structural_violations(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    config: ml_collections.ConfigDict
    ):
  """Computes several checks for structural violations.

  Args:
    batch: Batch dict; reads 'atom14_atom_exists', 'residue_index', 'aatype'
      and 'residx_atom14_to_atom37'.
    atom14_pred_positions: Predicted positions in atom14 layout, (N, 14, 3).
    config: Config providing 'violation_tolerance_factor' and
      'clash_overlap_tolerance'.

  Returns:
    Nested dict of per-residue/per-atom violation losses and masks, grouped
    into 'between_residues' and 'within_residues', plus a combined
    'total_per_residue_violations_mask'.
  """

  # Compute between residue backbone violations of bonds and angles.
  # The same tolerance factor is used for both the soft loss and the hard
  # violation mask.
  connection_violations = all_atom.between_residue_bond_loss(
      pred_atom_positions=atom14_pred_positions,
      pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
      residue_index=batch['residue_index'].astype(jnp.float32),
      aatype=batch['aatype'],
      tolerance_factor_soft=config.violation_tolerance_factor,
      tolerance_factor_hard=config.violation_tolerance_factor)

  # Compute the Van der Waals radius for every atom
  # (the first letter of the atom name is the element type).
  # Shape: (N, 14).
  atomtype_radius = jnp.array([
      residue_constants.van_der_waals_radius[name[0]]
      for name in residue_constants.atom_types
  ])
  atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
      atomtype_radius, batch['residx_atom14_to_atom37'])

  # Compute the between residue clash loss.
  between_residue_clashes = all_atom.between_residue_clash_loss(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'],
      atom14_atom_radius=atom14_atom_radius,
      residue_index=batch['residue_index'],
      overlap_tolerance_soft=config.clash_overlap_tolerance,
      overlap_tolerance_hard=config.clash_overlap_tolerance)

  # Compute all within-residue violations (clashes,
  # bond length and angle violations).
  restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
      overlap_tolerance=config.clash_overlap_tolerance,
      bond_length_tolerance_factor=config.violation_tolerance_factor)
  # Gather per-residue-type distance bounds for each residue in the chain.
  atom14_dists_lower_bound = utils.batched_gather(
      restype_atom14_bounds['lower_bound'], batch['aatype'])
  atom14_dists_upper_bound = utils.batched_gather(
      restype_atom14_bounds['upper_bound'], batch['aatype'])
  within_residue_violations = all_atom.within_residue_violations(
      atom14_pred_positions=atom14_pred_positions,
      atom14_atom_exists=batch['atom14_atom_exists'],
      atom14_dists_lower_bound=atom14_dists_lower_bound,
      atom14_dists_upper_bound=atom14_dists_upper_bound,
      tighten_bounds_for_loss=0.0)

  # Combine them to a single per-residue violation mask (used later for LDDT).
  # A residue counts as violating if any of the three checks flags it.
  per_residue_violations_mask = jnp.max(jnp.stack([
      connection_violations['per_residue_violation_mask'],
      jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1),
      jnp.max(within_residue_violations['per_atom_violations'], axis=-1)]),
                                        axis=0)

  return {
      'between_residues': {
          'bonds_c_n_loss_mean':
              connection_violations['c_n_loss_mean'],  # ()
          'angles_ca_c_n_loss_mean':
              connection_violations['ca_c_n_loss_mean'],  # ()
          'angles_c_n_ca_loss_mean':
              connection_violations['c_n_ca_loss_mean'],  # ()
          'connections_per_residue_loss_sum':
              connection_violations['per_residue_loss_sum'],  # (N)
          'connections_per_residue_violation_mask':
              connection_violations['per_residue_violation_mask'],  # (N)
          'clashes_mean_loss':
              between_residue_clashes['mean_loss'],  # ()
          'clashes_per_atom_loss_sum':
              between_residue_clashes['per_atom_loss_sum'],  # (N, 14)
          'clashes_per_atom_clash_mask':
              between_residue_clashes['per_atom_clash_mask'],  # (N, 14)
      },
      'within_residues': {
          'per_atom_loss_sum':
              within_residue_violations['per_atom_loss_sum'],  # (N, 14)
          'per_atom_violations':
              within_residue_violations['per_atom_violations'],  # (N, 14),
      },
      'total_per_residue_violations_mask':
          per_residue_violations_mask,  # (N)
  }
def compute_violation_metrics(
    batch: Dict[str, jnp.ndarray],
    atom14_pred_positions: jnp.ndarray,  # (N, 14, 3)
    violations: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
  """Compute several metrics to assess the structural violations."""
  seq_mask = batch['seq_mask']
  between = violations['between_residues']
  within = violations['within_residues']

  # Each metric is a scalar: either a raw violation count/fraction or a
  # sequence-masked mean of a per-residue violation indicator.
  return {
      'violations_extreme_ca_ca_distance':
          all_atom.extreme_ca_ca_distance_violations(
              pred_atom_positions=atom14_pred_positions,
              pred_atom_mask=batch['atom14_atom_exists'].astype(jnp.float32),
              residue_index=batch['residue_index'].astype(jnp.float32)),
      'violations_between_residue_bond':
          utils.mask_mean(
              mask=seq_mask,
              value=between['connections_per_residue_violation_mask']),
      'violations_between_residue_clash':
          utils.mask_mean(
              mask=seq_mask,
              value=jnp.max(between['clashes_per_atom_clash_mask'], axis=-1)),
      'violations_within_residue':
          utils.mask_mean(
              mask=seq_mask,
              value=jnp.max(within['per_atom_violations'], axis=-1)),
      'violations_per_residue':
          utils.mask_mean(
              mask=seq_mask,
              value=violations['total_per_residue_violations_mask']),
  }
def supervised_chi_loss(ret, batch, value, config):
  """Computes loss for direct chi angle supervision.

  Jumper et al. (2021) Suppl. Alg. 27 "torsionAngleLoss"

  Args:
    ret: Dictionary to write outputs into, needs to contain 'loss'.
    batch: Batch, needs to contain 'seq_mask', 'chi_mask', 'chi_angles'.
    value: Dictionary containing structure module output, needs to contain
      value['sidechains']['angles_sin_cos'] for angles and
      value['sidechains']['unnormalized_angles_sin_cos'] for unnormalized
      angles.
    config: Configuration of loss, should contain 'chi_weight' and
      'angle_norm_weight', 'angle_norm_weight' scales angle norm term,
      'chi_weight' scales torsion term.
  """
  eps = 1e-6

  sequence_mask = batch['seq_mask']
  num_res = sequence_mask.shape[0]
  chi_mask = batch['chi_mask'].astype(jnp.float32)
  pred_angles = jnp.reshape(
      value['sidechains']['angles_sin_cos'], [-1, num_res, 7, 2])
  # Drop the first 3 of the 7 predicted torsions; only the 4 chi angles are
  # supervised here.
  pred_angles = pred_angles[:, :, 3:]

  residue_type_one_hot = jax.nn.one_hot(
      batch['aatype'], residue_constants.restype_num + 1,
      dtype=jnp.float32)[None]
  # Look up, per residue, which chi angles are pi-periodic for its type.
  chi_pi_periodic = jnp.einsum(
      'ijk, kl->ijl', residue_type_one_hot,
      jnp.asarray(residue_constants.chi_pi_periodic))

  true_chi = batch['chi_angles'][None]
  sin_true_chi = jnp.sin(true_chi)
  cos_true_chi = jnp.cos(true_chi)
  sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1)

  # This is -1 if chi is pi-periodic and +1 if it's 2pi-periodic
  shifted_mask = (1 - 2 * chi_pi_periodic)[..., None]
  sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi

  # Take the smaller of the error against the target and against its
  # pi-shifted alternative, so pi-periodic chis are not penalized for the
  # equivalent rotamer.
  sq_chi_error = jnp.sum(
      squared_difference(sin_cos_true_chi, pred_angles), -1)
  sq_chi_error_shifted = jnp.sum(
      squared_difference(sin_cos_true_chi_shifted, pred_angles), -1)
  sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted)

  sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error)
  ret['chi_loss'] = sq_chi_loss
  ret['loss'] += config.chi_weight * sq_chi_loss

  # Penalize deviation of the raw (sin, cos) output vectors from unit norm;
  # eps keeps the sqrt differentiable at zero.
  unnormed_angles = jnp.reshape(
      value['sidechains']['unnormalized_angles_sin_cos'],
      [-1, num_res, 7, 2])
  angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps)
  norm_error = jnp.abs(angle_norm - 1.)
  angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None],
                                    value=norm_error)

  ret['angle_norm_loss'] = angle_norm_loss
  ret['loss'] += config.angle_norm_weight * angle_norm_loss
def generate_new_affine(sequence_mask):
  """Returns an identity rigid transform for every residue.

  Args:
    sequence_mask: (num_res, 1) mask; only its leading dimension is used.

  Returns:
    A quat_affine.QuatAffine with identity rotation and zero translation for
    each residue.
  """
  num_residues = sequence_mask.shape[0]
  # Identity quaternion (1, 0, 0, 0), replicated once per residue.
  quaternion = jnp.tile(jnp.asarray([[1., 0., 0., 0.]]), [num_residues, 1])
  translation = jnp.zeros([num_residues, 3])
  return quat_affine.QuatAffine(quaternion, translation, unstack_inputs=True)
def l2_normalize(x, axis=-1, epsilon=1e-12):
  """Scales `x` to unit L2 norm along `axis`.

  The squared norm is clamped from below by `epsilon`, so an all-zero input
  is divided by sqrt(epsilon) instead of zero.
  """
  squared_norm = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
  return x / jnp.sqrt(jnp.maximum(squared_norm, epsilon))
class MultiRigidSidechain(hk.Module):
  """Class to make side chain atoms."""

  def __init__(self, config, global_config, name='rigid_sidechain'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, affine, representations_list, aatype):
    """Predict side chains using multi-rigid representations.

    Args:
      affine: The affines for each residue (translations in angstroms).
      representations_list: A list of activations to predict side chains from.
      aatype: Amino acid types.

    Returns:
      Dict containing atom positions and frames (in angstroms).
    """
    # Project each input representation to the sidechain channel width.
    act = [
        common_modules.Linear(  # pylint: disable=g-complex-comprehension
            self.config.num_channel,
            name='input_projection')(jax.nn.relu(x))
        for x in representations_list
    ]
    # Sum the activation list (equivalent to concat then Linear).
    act = sum(act)

    # Zero-init the final layer of each residual block when configured, so the
    # blocks start as identity maps.
    final_init = 'zeros' if self.global_config.zero_init else 'linear'

    # Mapping with some residual blocks.
    for _ in range(self.config.num_residual_block):
      old_act = act
      act = common_modules.Linear(
          self.config.num_channel,
          initializer='relu',
          name='resblock1')(
              jax.nn.relu(act))
      act = common_modules.Linear(
          self.config.num_channel,
          initializer=final_init,
          name='resblock2')(
              jax.nn.relu(act))
      act += old_act

    # Map activations to torsion angles. Shape: (num_res, 14).
    # The 14 outputs are reshaped to 7 torsions x (sin, cos) per residue.
    num_res = act.shape[0]
    unnormalized_angles = common_modules.Linear(
        14, name='unnormalized_angles')(
            jax.nn.relu(act))
    unnormalized_angles = jnp.reshape(
        unnormalized_angles, [num_res, 7, 2])
    # Normalize each (sin, cos) pair onto the unit circle.
    angles = l2_normalize(unnormalized_angles, axis=-1)

    outputs = {
        'angles_sin_cos': angles,  # jnp.ndarray (N, 7, 2)
        'unnormalized_angles_sin_cos':
            unnormalized_angles,  # jnp.ndarray (N, 7, 2)
    }

    # Map torsion angles to frames.
    backb_to_global = r3.rigids_from_quataffine(affine)

    # Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates"

    # r3.Rigids with shape (N, 8).
    all_frames_to_global = all_atom.torsion_angles_to_frames(
        aatype,
        backb_to_global,
        angles)

    # Use frames and literature positions to create the final atom coordinates.
    # r3.Vecs with shape (N, 14).
    pred_positions = all_atom.frames_and_literature_positions_to_atom14_pos(
        aatype, all_frames_to_global)

    outputs.update({
        'atom_pos': pred_positions,  # r3.Vecs (N, 14)
        'frames': all_frames_to_global,  # r3.Rigids (N, 8)
    })
    return outputs
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment