Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4bd1b4d5
Commit
4bd1b4d5
authored
Apr 28, 2022
by
Gustaf Ahdritz
Browse files
Work on multimer continues
parent
54164fe8
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2327 additions
and
689 deletions
+2327
-689
openfold/__init__.py
openfold/__init__.py
+1
-0
openfold/config.py
openfold/config.py
+84
-6
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+152
-145
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+16
-4
openfold/data/data_transforms_multimer.py
openfold/data/data_transforms_multimer.py
+303
-0
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+21
-7
openfold/data/feature_processing_multimer.py
openfold/data/feature_processing_multimer.py
+13
-9
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+135
-0
openfold/data/msa_identifiers.py
openfold/data/msa_identifiers.py
+2
-3
openfold/data/msa_pairing.py
openfold/data/msa_pairing.py
+69
-212
openfold/data/parsers.py
openfold/data/parsers.py
+20
-3
openfold/data/templates.py
openfold/data/templates.py
+82
-60
openfold/data/tools/hhsearch.py
openfold/data/tools/hhsearch.py
+6
-4
openfold/data/tools/hmmsearch.py
openfold/data/tools/hmmsearch.py
+22
-19
openfold/model/embedders.py
openfold/model/embedders.py
+583
-13
openfold/model/model.py
openfold/model/model.py
+79
-111
openfold/model/structure_module.py
openfold/model/structure_module.py
+215
-92
openfold/np/protein.py
openfold/np/protein.py
+1
-1
openfold/utils/all_atom_multimer.py
openfold/utils/all_atom_multimer.py
+493
-0
openfold/utils/argparse_utils.py
openfold/utils/argparse_utils.py
+30
-0
No files found.
openfold/__init__.py
View file @
4bd1b4d5
from
.
import
model
from
.
import
model
from
.
import
utils
from
.
import
utils
from
.
import
data
from
.
import
np
from
.
import
np
from
.
import
resources
from
.
import
resources
...
...
openfold/config.py
View file @
4bd1b4d5
...
@@ -75,7 +75,17 @@ def model_config(name, train=False, low_prec=False):
...
@@ -75,7 +75,17 @@ def model_config(name, train=False, low_prec=False):
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
elif
"multimer"
in
name
:
c
.
model
.
update
(
multimer_model_config_update
)
c
.
globals
.
is_multimer
=
True
for
k
,
v
in
multimer_model_config_update
.
items
():
c
.
model
[
k
]
=
v
c
.
data
.
common
.
unsupervised_features
.
extend
([
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
])
else
:
else
:
raise
ValueError
(
"Invalid model name"
)
raise
ValueError
(
"Invalid model name"
)
...
@@ -276,6 +286,7 @@ config = mlc.ConfigDict(
...
@@ -276,6 +286,7 @@ config = mlc.ConfigDict(
"c_e"
:
c_e
,
"c_e"
:
c_e
,
"c_s"
:
c_s
,
"c_s"
:
c_s
,
"eps"
:
eps
,
"eps"
:
eps
,
"is_multimer"
:
False
,
},
},
"model"
:
{
"model"
:
{
"_mask_trans"
:
False
,
"_mask_trans"
:
False
,
...
@@ -335,6 +346,7 @@ config = mlc.ConfigDict(
...
@@ -335,6 +346,7 @@ config = mlc.ConfigDict(
"eps"
:
eps
,
# 1e-6,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
False
,
},
},
"extra_msa"
:
{
"extra_msa"
:
{
"extra_msa_embedder"
:
{
"extra_msa_embedder"
:
{
...
@@ -496,10 +508,76 @@ config = mlc.ConfigDict(
...
@@ -496,10 +508,76 @@ config = mlc.ConfigDict(
}
}
)
)
multimer_model_config_update
=
mlc
.
ConfigDict
(
multimer_model_config_update
=
{
"relative_encoding"
:
{
"input_embedder"
:
{
"enabled"
:
True
,
"tf_dim"
:
21
,
"msa_dim"
:
49
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"max_relative_chain"
:
2
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
"max_relative_idx"
:
32
,
}
"use_chain_relative"
:
True
,
)
},
"template"
:
{
"distogram"
:
{
"min_bin"
:
3.25
,
"max_bin"
:
50.75
,
"no_bins"
:
39
,
},
"template_pair_embedder"
:
{
"c_z"
:
c_z
,
"c_out"
:
64
,
"c_dgram"
:
39
,
"c_aatype"
:
22
,
},
"template_single_embedder"
:
{
"c_in"
:
34
,
"c_m"
:
c_m
,
},
"template_pair_stack"
:
{
"c_t"
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att"
:
16
,
"c_hidden_tri_mul"
:
64
,
"no_blocks"
:
2
,
"no_heads"
:
4
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"inf"
:
1e9
,
},
"c_t"
:
c_t
,
"c_z"
:
c_z
,
"inf"
:
1e5
,
# 1e9,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
},
"heads"
:
{
"lddt"
:
{
"no_bins"
:
50
,
"c_in"
:
c_s
,
"c_hidden"
:
128
,
},
"distogram"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
},
"tm"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"enabled"
:
tm_enabled
,
},
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_out"
:
22
,
},
"experimentally_resolved"
:
{
"c_s"
:
c_s
,
"c_out"
:
37
,
},
},
}
openfold/data/data_pipeline.py
View file @
4bd1b4d5
...
@@ -14,8 +14,13 @@
...
@@ -14,8 +14,13 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
collections
import
contextlib
import
dataclasses
import
datetime
import
datetime
import
json
from
multiprocessing
import
cpu_count
from
multiprocessing
import
cpu_count
import
tempfile
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
,
MutableMapping
,
Union
from
typing
import
Mapping
,
Optional
,
Sequence
,
Any
,
MutableMapping
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -26,7 +31,9 @@ from openfold.data import (
...
@@ -26,7 +31,9 @@ from openfold.data import (
mmcif_parsing
,
mmcif_parsing
,
msa_identifiers
,
msa_identifiers
,
msa_pairing
,
msa_pairing
,
feature_processing_multimer
,
)
)
from
openfold.data.parsers
import
Msa
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools
import
jackhmmer
,
hhblits
,
hhsearch
,
hmmsearch
from
openfold.data.tools.utils
import
to_date
from
openfold.data.tools.utils
import
to_date
from
openfold.np
import
residue_constants
,
protein
from
openfold.np
import
residue_constants
,
protein
...
@@ -59,8 +66,6 @@ def make_template_features(
...
@@ -59,8 +66,6 @@ def make_template_features(
else
:
else
:
templates_result
=
template_featurizer
.
get_templates
(
templates_result
=
template_featurizer
.
get_templates
(
query_sequence
=
input_sequence
,
query_sequence
=
input_sequence
,
query_pdb_code
=
query_pdb_code
,
query_release_date
=
query_release_date
,
hits
=
hits_cat
,
hits
=
hits_cat
,
)
)
template_features
=
templates_result
.
features
template_features
=
templates_result
.
features
...
@@ -195,7 +200,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
...
@@ -195,7 +200,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
int_msa
=
[]
int_msa
=
[]
deletion_matrix
=
[]
deletion_matrix
=
[]
uniprot_accession_ids
=
[]
species_ids
=
[]
species_ids
=
[]
seen_sequences
=
set
()
seen_sequences
=
set
()
for
msa_index
,
msa
in
enumerate
(
msas
):
for
msa_index
,
msa
in
enumerate
(
msas
):
...
@@ -215,9 +219,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
...
@@ -215,9 +219,6 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
identifiers
=
msa_identifiers
.
get_identifiers
(
identifiers
=
msa_identifiers
.
get_identifiers
(
msa
.
descriptions
[
sequence_index
]
msa
.
descriptions
[
sequence_index
]
)
)
uniprot_accession_ids
.
append
(
identifiers
.
uniprot_accession_id
.
encode
(
'utf-8'
)
)
species_ids
.
append
(
identifiers
.
species_id
.
encode
(
'utf-8'
))
species_ids
.
append
(
identifiers
.
species_id
.
encode
(
'utf-8'
))
num_res
=
len
(
msas
[
0
].
sequences
[
0
])
num_res
=
len
(
msas
[
0
].
sequences
[
0
])
...
@@ -228,42 +229,25 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
...
@@ -228,42 +229,25 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features
[
"num_alignments"
]
=
np
.
array
(
features
[
"num_alignments"
]
=
np
.
array
(
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
[
num_alignments
]
*
num_res
,
dtype
=
np
.
int32
)
)
features
[
"msa_uniprot_accession_identifiers"
]
=
np
.
array
(
uniprot_accession_ids
,
dtype
=
np
.
object_
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
np
.
object_
)
features
[
"msa_species_identifiers"
]
=
np
.
array
(
species_ids
,
dtype
=
np
.
object_
)
return
features
return
features
def
run_msa_tool
(
def
run_msa_tool
(
msa_runner
,
msa_runner
,
input_
fasta_path
:
str
,
fasta_path
:
str
,
msa_out_path
:
str
,
msa_out_path
:
str
,
msa_format
:
str
,
msa_format
:
str
,
use_precomputed_msas
:
bool
,
max_sto_sequences
:
Optional
[
int
]
=
None
,
max_sto_sequences
:
Optional
[
int
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
"""Runs an MSA tool, checking if output already exists first."""
"""Runs an MSA tool, checking if output already exists first."""
if
not
use_precomputed_msas
or
not
os
.
path
.
exists
(
msa_out_path
):
if
(
msa_format
==
"sto"
and
max_sto_sequences
is
not
None
):
if
(
msa_format
==
"sto"
and
max_sto_sequences
is
not
None
):
result
=
msa_runner
.
query
(
fasta_path
,
max_sto_sequences
)[
0
]
result
=
msa_runner
.
query
(
input_fasta_path
,
max_sto_sequences
)[
0
]
else
:
result
=
msa_runner
.
query
(
input_fasta_path
)[
0
]
result_a3m
=
parsers
.
convert_stockholm_to_a3m
(
result
[
"sto"
])
with
open
(
msa_out_path
,
"w"
)
as
f
:
f
.
write
(
result_a3m
)
else
:
else
:
logging
.
warning
(
"Reading MSA from file %s"
,
msa_out_path
)
result
=
msa_runner
.
query
(
fasta_path
)[
0
]
if
(
msa_format
==
"sto"
and
max_sto_sequences
is
not
None
):
precomputed_msa
=
parsers
.
truncate_stockholm_msa
(
with
open
(
msa_out_path
,
"w"
)
as
f
:
msa_out_path
,
f
.
write
(
result
[
msa_format
])
max_sto_sequences
,
)
result
=
{
"sto"
:
precomputed_msa
}
else
:
with
open
(
msa_out_path
,
"r"
)
as
f
:
result
=
{
msa_format
:
f
.
read
()}
return
result
return
result
...
@@ -413,7 +397,7 @@ class AlignmentRunner:
...
@@ -413,7 +397,7 @@ class AlignmentRunner:
jackhmmer_uniref90_result
=
run_msa_tool
(
jackhmmer_uniref90_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_uniref90_runner
,
msa_runner
=
self
.
jackhmmer_uniref90_runner
,
input_
fasta_path
=
fasta_path
,
fasta_path
=
fasta_path
,
msa_out_path
=
uniref90_out_path
,
msa_out_path
=
uniref90_out_path
,
msa_format
=
'sto'
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniref_max_hits
,
max_sto_sequences
=
self
.
uniref_max_hits
,
...
@@ -427,13 +411,17 @@ class AlignmentRunner:
...
@@ -427,13 +411,17 @@ class AlignmentRunner:
if
(
self
.
template_searcher
is
not
None
):
if
(
self
.
template_searcher
is
not
None
):
if
(
self
.
template_searcher
.
input_format
==
"sto"
):
if
(
self
.
template_searcher
.
input_format
==
"sto"
):
pdb_templates_result
=
self
.
template_searcher
.
query
(
template_msa
)
pdb_templates_result
=
self
.
template_searcher
.
query
(
template_msa
,
output_dir
=
output_dir
)
elif
(
self
.
template_searcher
.
input_format
==
"a3m"
):
elif
(
self
.
template_searcher
.
input_format
==
"a3m"
):
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
uniref90_msa_as_a3m
=
parsers
.
convert_stockholm_to_a3m
(
template_msa
template_msa
)
)
pdb_templates_result
=
self
.
template_searcher
.
query
(
pdb_templates_result
=
self
.
template_searcher
.
query
(
uniref90_msa_as_a3m
uniref90_msa_as_a3m
,
output_dir
=
output_dir
)
)
else
:
else
:
fmt
=
self
.
template_searcher
.
input_format
fmt
=
self
.
template_searcher
.
input_format
...
@@ -445,7 +433,7 @@ class AlignmentRunner:
...
@@ -445,7 +433,7 @@ class AlignmentRunner:
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.a3m"
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.a3m"
)
jackhmmer_mgnify_result
=
run_msa_tool
(
jackhmmer_mgnify_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_mgnify_runner
,
msa_runner
=
self
.
jackhmmer_mgnify_runner
,
input_
fasta_path
=
fasta_path
,
fasta_path
=
fasta_path
,
msa_out_path
=
mgnify_out_path
,
msa_out_path
=
mgnify_out_path
,
msa_format
=
'sto'
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
mgnify_max_hits
max_sto_sequences
=
self
.
mgnify_max_hits
...
@@ -455,7 +443,7 @@ class AlignmentRunner:
...
@@ -455,7 +443,7 @@ class AlignmentRunner:
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"small_bfd_hits.sto"
)
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"small_bfd_hits.sto"
)
jackhmmer_small_bfd_result
=
run_msa_tool
(
jackhmmer_small_bfd_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_small_bfd_runner
,
msa_runner
=
self
.
jackhmmer_small_bfd_runner
,
input_
fasta_path
=
fasta_path
,
fasta_path
=
fasta_path
,
msa_out_path
=
bfd_out_path
,
msa_out_path
=
bfd_out_path
,
msa_format
=
"sto"
,
msa_format
=
"sto"
,
)
)
...
@@ -463,7 +451,7 @@ class AlignmentRunner:
...
@@ -463,7 +451,7 @@ class AlignmentRunner:
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"bfd_uniclust_hits.a3m"
)
bfd_out_path
=
os
.
path
.
join
(
output_dir
,
"bfd_uniclust_hits.a3m"
)
hhblits_bfd_uniclust_result
=
run_msa_tool
(
hhblits_bfd_uniclust_result
=
run_msa_tool
(
msa_runner
=
self
.
hhblits_bfd_uniclust_runner
,
msa_runner
=
self
.
hhblits_bfd_uniclust_runner
,
input_
fasta_path
=
fasta_path
,
fasta_path
=
fasta_path
,
msa_out_path
=
bfd_out_path
,
msa_out_path
=
bfd_out_path
,
msa_format
=
"a3m"
,
msa_format
=
"a3m"
,
)
)
...
@@ -472,7 +460,7 @@ class AlignmentRunner:
...
@@ -472,7 +460,7 @@ class AlignmentRunner:
uniprot_out_path
=
os
.
path
.
join
(
output_dir
,
'uniprot_hits.sto'
)
uniprot_out_path
=
os
.
path
.
join
(
output_dir
,
'uniprot_hits.sto'
)
result
=
run_msa_tool
(
result
=
run_msa_tool
(
self
.
jackhmmer_uniprot_runner
,
self
.
jackhmmer_uniprot_runner
,
input_
fasta_path
=
input_
fasta_path
,
fasta_path
=
fasta_path
,
msa_out_path
=
uniprot_out_path
,
msa_out_path
=
uniprot_out_path
,
msa_format
=
'sto'
,
msa_format
=
'sto'
,
max_sto_sequences
=
self
.
uniprot_max_hits
,
max_sto_sequences
=
self
.
uniprot_max_hits
,
...
@@ -485,7 +473,7 @@ class _FastaChain:
...
@@ -485,7 +473,7 @@ class _FastaChain:
description
:
str
description
:
str
def
_make_chain_id_map
(
*
,
def
_make_chain_id_map
(
sequences
:
Sequence
[
str
],
sequences
:
Sequence
[
str
],
descriptions
:
Sequence
[
str
],
descriptions
:
Sequence
[
str
],
)
->
Mapping
[
str
,
_FastaChain
]:
)
->
Mapping
[
str
,
_FastaChain
]:
...
@@ -498,9 +486,11 @@ def _make_chain_id_map(*,
...
@@ -498,9 +486,11 @@ def _make_chain_id_map(*,
f
'Got
{
len
(
sequences
)
}
chains.'
)
f
'Got
{
len
(
sequences
)
}
chains.'
)
chain_id_map
=
{}
chain_id_map
=
{}
for
chain_id
,
sequence
,
description
in
zip
(
for
chain_id
,
sequence
,
description
in
zip
(
protein
.
PDB_CHAIN_IDS
,
sequences
,
descriptions
):
protein
.
PDB_CHAIN_IDS
,
sequences
,
descriptions
chain_id_map
[
chain_id
]
=
_FastaChain
(
):
sequence
=
sequence
,
description
=
description
)
chain_id_map
[
chain_id
]
=
_FastaChain
(
sequence
=
sequence
,
description
=
description
)
return
chain_id_map
return
chain_id_map
...
@@ -520,7 +510,8 @@ def convert_monomer_features(
...
@@ -520,7 +510,8 @@ def convert_monomer_features(
converted
=
{}
converted
=
{}
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object_
)
converted
[
'auth_chain_id'
]
=
np
.
asarray
(
chain_id
,
dtype
=
np
.
object_
)
unnecessary_leading_dim_feats
=
{
unnecessary_leading_dim_feats
=
{
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
'sequence'
,
'domain_name'
,
'num_alignments'
,
'seq_length'
}
for
feature_name
,
feature
in
monomer_features
.
items
():
for
feature_name
,
feature
in
monomer_features
.
items
():
if
feature_name
in
unnecessary_leading_dim_feats
:
if
feature_name
in
unnecessary_leading_dim_feats
:
# asarray ensures it's a np.ndarray.
# asarray ensures it's a np.ndarray.
...
@@ -591,9 +582,15 @@ def add_assembly_features(
...
@@ -591,9 +582,15 @@ def add_assembly_features(
new_all_chain_features
[
new_all_chain_features
[
f
'
{
int_id_to_str_id
(
entity_id
)
}
_
{
sym_id
}
'
]
=
chain_features
f
'
{
int_id_to_str_id
(
entity_id
)
}
_
{
sym_id
}
'
]
=
chain_features
seq_length
=
chain_features
[
'seq_length'
]
seq_length
=
chain_features
[
'seq_length'
]
chain_features
[
'asym_id'
]
=
chain_id
*
np
.
ones
(
seq_length
)
chain_features
[
'asym_id'
]
=
(
chain_features
[
'sym_id'
]
=
sym_id
*
np
.
ones
(
seq_length
)
chain_id
*
np
.
ones
(
seq_length
)
chain_features
[
'entity_id'
]
=
entity_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_features
[
'sym_id'
]
=
(
sym_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_features
[
'entity_id'
]
=
(
entity_id
*
np
.
ones
(
seq_length
)
).
astype
(
np
.
int64
)
chain_id
+=
1
chain_id
+=
1
return
new_all_chain_features
return
new_all_chain_features
...
@@ -624,8 +621,7 @@ class DataPipeline:
...
@@ -624,8 +621,7 @@ class DataPipeline:
alignment_dir
:
str
,
alignment_dir
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
,
_alignment_index
:
Optional
[
Any
]
=
None
,
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msa_data
=
{}
msas
=
{}
if
(
_alignment_index
is
not
None
):
if
(
_alignment_index
is
not
None
):
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
fp
=
open
(
os
.
path
.
join
(
alignment_dir
,
_alignment_index
[
"db"
]),
"rb"
)
...
@@ -635,14 +631,16 @@ class DataPipeline:
...
@@ -635,14 +631,16 @@ class DataPipeline:
return
msa
return
msa
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
for
(
name
,
start
,
size
)
in
_alignment_index
[
"files"
]:
ext
=
os
.
path
.
splitext
(
name
)
[
-
1
]
filename
,
ext
=
os
.
path
.
splitext
(
name
)
if
(
ext
==
".a3m"
):
if
(
ext
==
".a3m"
):
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
read_msa
(
start
,
size
)
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
):
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
)
read_msa
(
start
,
size
)
)
)
...
@@ -656,28 +654,27 @@ class DataPipeline:
...
@@ -656,28 +654,27 @@ class DataPipeline:
else
:
else
:
for
f
in
os
.
listdir
(
alignment_dir
):
for
f
in
os
.
listdir
(
alignment_dir
):
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
path
=
os
.
path
.
join
(
alignment_dir
,
f
)
ext
=
os
.
path
.
splitext
(
f
)
[
-
1
]
filename
,
ext
=
os
.
path
.
splitext
(
f
)
if
(
ext
==
".a3m"
):
if
(
ext
==
".a3m"
):
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
=
parsers
.
parse_a3m
(
fp
.
read
())
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
elif
(
ext
==
".sto"
):
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
,
deletion_matrix
,
_
=
parsers
.
parse_stockholm
(
msa
=
parsers
.
parse_stockholm
(
fp
.
read
()
fp
.
read
()
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
deletion_matrix
}
else
:
else
:
continue
continue
msa
_data
[
f
]
=
dat
a
msa
s
[
f
]
=
ms
a
return
msa
_data
return
msa
s
def
_parse_template_hits
(
def
_parse_template_hit
_file
s
(
self
,
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
str
,
_alignment_index
:
Optional
[
Any
]
=
None
_alignment_index
:
Optional
[
Any
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
all_hits
=
{}
all_hits
=
{}
...
@@ -694,6 +691,12 @@ class DataPipeline:
...
@@ -694,6 +691,12 @@ class DataPipeline:
if
(
ext
==
".hhr"
):
if
(
ext
==
".hhr"
):
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
hits
=
parsers
.
parse_hhr
(
read_template
(
start
,
size
))
all_hits
[
name
]
=
hits
all_hits
[
name
]
=
hits
elif
(
name
==
"hmmsearch_output.sto"
):
hits
=
parsers
.
parse_hmmsearch_sto
(
read_template
(
start
,
size
),
input_sequence
,
)
all_hits
[
name
]
=
hits
fp
.
close
()
fp
.
close
()
else
:
else
:
...
@@ -705,6 +708,13 @@ class DataPipeline:
...
@@ -705,6 +708,13 @@ class DataPipeline:
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
all_hits
[
f
]
=
hits
elif
(
f
==
"hmm_output.sto"
):
with
open
(
path
,
"r"
)
as
fp
:
hits
=
parsers
.
parse_hmmsearch_sto
(
fp
.
read
(),
input_sequence
,
)
all_hits
[
f
]
=
hits
return
all_hits
return
all_hits
...
@@ -714,9 +724,9 @@ class DataPipeline:
...
@@ -714,9 +724,9 @@ class DataPipeline:
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
_alignment_index
:
Optional
[
str
]
=
None
_alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msa
_data
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
msa
s
=
self
.
_parse_msa_data
(
alignment_dir
,
_alignment_index
)
if
(
len
(
msa
_data
)
==
0
):
if
(
len
(
msa
s
)
==
0
):
if
(
input_sequence
is
None
):
if
(
input_sequence
is
None
):
raise
ValueError
(
raise
ValueError
(
"""
"""
...
@@ -724,18 +734,13 @@ class DataPipeline:
...
@@ -724,18 +734,13 @@ class DataPipeline:
must be provided.
must be provided.
"""
"""
)
)
msa_data
[
"dummy"
]
=
{
msa_data
[
"dummy"
]
=
Msa
(
"msa"
:
[
input_sequence
],
[
input_sequence
],
"deletion_matrix"
:
[[
0
for
_
in
input_sequence
]],
[[
0
for
_
in
input_sequence
]],
}
[
"dummy"
]
)
msas
,
deletion_matrices
=
zip
(
*
[
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
msa_objects
=
[
Msa
(
m
,
d
)
for
m
,
d
in
zip
(
msas
,
deletion_matrices
)]
msa_features
=
make_msa_features
(
msa_objects
)
msa_features
=
make_msa_features
(
list
(
msas
.
values
())
)
return
msa_features
return
msa_features
...
@@ -757,7 +762,12 @@ class DataPipeline:
...
@@ -757,7 +762,12 @@ class DataPipeline:
input_description
=
input_descs
[
0
]
input_description
=
input_descs
[
0
]
num_res
=
len
(
input_sequence
)
num_res
=
len
(
input_sequence
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hit_files
(
alignment_dir
,
input_sequence
,
_alignment_index
,
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -801,7 +811,10 @@ class DataPipeline:
...
@@ -801,7 +811,10 @@ class DataPipeline:
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
input_sequence
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -836,7 +849,11 @@ class DataPipeline:
...
@@ -836,7 +849,11 @@ class DataPipeline:
is_distillation
is_distillation
)
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
input_sequence
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -864,7 +881,11 @@ class DataPipeline:
...
@@ -864,7 +881,11 @@ class DataPipeline:
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
description
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
core_path
))[
0
].
upper
()
core_feats
=
make_protein_features
(
protein_object
,
description
)
core_feats
=
make_protein_features
(
protein_object
,
description
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
_alignment_index
)
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
input_sequence
,
_alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
input_sequence
,
input_sequence
,
hits
,
hits
,
...
@@ -881,117 +902,103 @@ class DataPipelineMultimer:
...
@@ -881,117 +902,103 @@ class DataPipelineMultimer:
def
__init__
(
self
,
def
__init__
(
self
,
monomer_data_pipeline
:
DataPipeline
,
monomer_data_pipeline
:
DataPipeline
,
jackhmmer_binary_path
:
str
,
uniprot_database_path
:
str
,
max_uniprot_hits
:
int
=
50000
,
):
):
"""Initializes the data pipeline.
"""Initializes the data pipeline.
Args:
Args:
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs
the data pipeline for the monomer AlphaFold system.
the data pipeline for the monomer AlphaFold system.
jackhmmer_binary_path: Location of the jackhmmer binary.
jackhmmer_binary_path: Location of the jackhmmer binary.
uniprot_database_path: Location of the unclustered uniprot sequences, that
uniprot_database_path: Location of the unclustered uniprot sequences, that
will be searched with jackhmmer and used for MSA pairing.
will be searched with jackhmmer and used for MSA pairing.
max_uniprot_hits: The maximum number of hits to return from uniprot.
max_uniprot_hits: The maximum number of hits to return from uniprot.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
"""
"""
self
.
_monomer_data_pipeline
=
monomer_data_pipeline
self
.
_monomer_data_pipeline
=
monomer_data_pipeline
def
_process_single_chain
(
def
_process_single_chain
(
self
,
self
,
chain_id
:
str
,
chain_id
:
str
,
sequence
:
str
,
sequence
:
str
,
description
:
str
,
description
:
str
,
msa_outpu
t_dir
:
str
,
chain_alignmen
t_dir
:
str
,
is_homomer_or_monomer
:
bool
is_homomer_or_monomer
:
bool
)
->
FeatureDict
:
)
->
FeatureDict
:
"""Runs the monomer pipeline on a single chain."""
"""Runs the monomer pipeline on a single chain."""
chain_fasta_str
=
f
'>chain_
{
chain_id
}
\n
{
sequence
}
\n
'
chain_fasta_str
=
f
'>
{
chain_id
}
\n
{
sequence
}
\n
'
chain_msa_output_dir
=
os
.
path
.
join
(
msa_output_dir
,
chain_id
)
if
not
os
.
path
.
exists
(
chain_alignment_dir
):
if
not
os
.
path
.
exists
(
chain_msa_output_dir
):
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
raise
ValueError
(
f
"Alignments for
{
chain_id
}
not found..."
)
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
with
temp_fasta_file
(
chain_fasta_str
)
as
chain_fasta_path
:
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
chain_features
=
self
.
_monomer_data_pipeline
.
process_fasta
(
input_
fasta_path
=
chain_fasta_path
,
fasta_path
=
chain_fasta_path
,
alignment_dir
=
chain_
msa_outpu
t_dir
alignment_dir
=
chain_
alignmen
t_dir
)
)
# We only construct the pairing features if there are 2 or more unique
# We only construct the pairing features if there are 2 or more unique
# sequences.
# sequences.
if
not
is_homomer_or_monomer
:
if
not
is_homomer_or_monomer
:
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_fasta_path
,
all_seq_msa_features
=
self
.
_all_seq_msa_features
(
chain_msa_output_dir
)
chain_fasta_path
,
chain_alignment_dir
)
chain_features
.
update
(
all_seq_msa_features
)
chain_features
.
update
(
all_seq_msa_features
)
return
chain_features
return
chain_features
def
_all_seq_msa_features
(
self
,
input_fasta_path
,
msa_output_dir
):
def
_all_seq_msa_features
(
self
,
fasta_path
,
alignment_dir
):
"""Get MSA features for unclustered uniprot, for pairing."""
"""Get MSA features for unclustered uniprot, for pairing."""
uniprot_msa_path
=
os
.
path
.
join
(
msa_output_dir
,
"uniprot_hits.sto"
)
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
all_seq_features
=
make_msa_features
([
msa
])
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_uniprot_accession_identifiers'
,
'msa_species_identifiers'
,
'msa_species_identifiers'
,
)
)
feats
=
{
feats
=
{
f
'
{
k
}
_all_seq'
:
v
for
k
,
v
in
all_seq_features
.
items
()
f
'
{
k
}
_all_seq'
:
v
for
k
,
v
in
all_seq_features
.
items
()
if
k
in
valid_feats
if
k
in
valid_feats
}
}
return
feats
return
feats
def
process
(
self
,
def
process_fasta
(
self
,
input_fasta_path
:
str
,
fasta_path
:
str
,
msa_output_dir
:
str
,
alignment_dir
:
str
,
is_prokaryote
:
bool
=
False
)
->
FeatureDict
:
)
->
FeatureDict
:
"""
Runs alignment tools on the input sequences and c
reates features."""
"""
C
reates features."""
with
open
(
input_
fasta_path
)
as
f
:
with
open
(
fasta_path
)
as
f
:
input_fasta_str
=
f
.
read
()
input_fasta_str
=
f
.
read
()
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
input_seqs
,
input_descs
=
parsers
.
parse_fasta
(
input_fasta_str
)
chain_id_map
=
_make_chain_id_map
(
sequences
=
input_seqs
,
descriptions
=
input_descs
)
chain_id_map_path
=
os
.
path
.
join
(
msa_output_dir
,
'chain_id_map.json'
)
with
open
(
chain_id_map_path
,
'w'
)
as
f
:
chain_id_map_dict
=
{
chain_id
:
dataclasses
.
asdict
(
fasta_chain
)
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
()
}
json
.
dump
(
chain_id_map_dict
,
f
,
indent
=
4
,
sort_keys
=
True
)
all_chain_features
=
{}
all_chain_features
=
{}
sequence_features
=
{}
sequence_features
=
{}
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
is_homomer_or_monomer
=
len
(
set
(
input_seqs
))
==
1
for
chain_id
,
fasta_chain
in
chain_id_map
.
items
():
for
desc
,
seq
in
zip
(
input_descs
,
input_seqs
):
if
fasta_chain
.
sequence
in
sequence_features
:
if
seq
in
sequence_features
:
all_chain_features
[
chain_id
]
=
copy
.
deepcopy
(
all_chain_features
[
desc
]
=
copy
.
deepcopy
(
sequence_features
[
fasta_chain
.
sequence
])
sequence_features
[
seq
]
)
continue
continue
chain_features
=
self
.
_process_single_chain
(
chain_features
=
self
.
_process_single_chain
(
chain_id
=
chain_id
,
chain_id
=
desc
,
sequence
=
fasta_chain
.
sequence
,
sequence
=
seq
,
description
=
fasta_chain
.
description
,
description
=
desc
,
msa_output_dir
=
msa_output_dir
,
chain_alignment_dir
=
os
.
path
.
join
(
alignment_dir
,
desc
)
,
is_homomer_or_monomer
=
is_homomer_or_monomer
is_homomer_or_monomer
=
is_homomer_or_monomer
)
)
chain_features
=
convert_monomer_features
(
chain_features
=
convert_monomer_features
(
chain_features
,
chain_features
,
chain_id
=
chain_id
chain_id
=
desc
)
)
all_chain_features
[
chain_id
]
=
chain_features
all_chain_features
[
desc
]
=
chain_features
sequence_features
[
fasta_chain
.
sequence
]
=
chain_features
sequence_features
[
seq
]
=
chain_features
all_chain_features
=
add_assembly_features
(
all_chain_features
)
all_chain_features
=
add_assembly_features
(
all_chain_features
)
np_example
=
feature_processing
.
pair_and_merge
(
np_example
=
feature_processing
_multimer
.
pair_and_merge
(
all_chain_features
=
all_chain_features
,
all_chain_features
=
all_chain_features
,
is_prokaryote
=
is_prokaryote
,
)
)
# Pad MSA to avoid zero-sized extra_msa.
# Pad MSA to avoid zero-sized extra_msa.
...
...
openfold/data/data_transforms.py
View file @
4bd1b4d5
...
@@ -428,10 +428,16 @@ def make_hhblits_profile(protein):
...
@@ -428,10 +428,16 @@ def make_hhblits_profile(protein):
@
curry1
@
curry1
def
make_masked_msa
(
protein
,
config
,
replace_fraction
):
def
make_masked_msa
(
protein
,
config
,
replace_fraction
,
seed
):
"""Create data for BERT on raw MSA."""
"""Create data for BERT on raw MSA."""
device
=
protein
[
"msa"
].
device
# Add a random amino acid uniformly.
# Add a random amino acid uniformly.
random_aa
=
torch
.
tensor
([
0.05
]
*
20
+
[
0.0
,
0.0
],
dtype
=
torch
.
float32
)
random_aa
=
torch
.
tensor
(
[
0.05
]
*
20
+
[
0.0
,
0.0
],
dtype
=
torch
.
float32
,
device
=
device
)
categorical_probs
=
(
categorical_probs
=
(
config
.
uniform_prob
*
random_aa
config
.
uniform_prob
*
random_aa
...
@@ -449,11 +455,17 @@ def make_masked_msa(protein, config, replace_fraction):
...
@@ -449,11 +455,17 @@ def make_masked_msa(protein, config, replace_fraction):
)
)
assert
mask_prob
>=
0.0
assert
mask_prob
>=
0.0
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
=
torch
.
nn
.
functional
.
pad
(
categorical_probs
,
pad_shapes
,
value
=
mask_prob
categorical_probs
,
pad_shapes
,
value
=
mask_prob
,
)
)
sh
=
protein
[
"msa"
].
shape
sh
=
protein
[
"msa"
].
shape
mask_position
=
torch
.
rand
(
sh
)
<
replace_fraction
g
=
torch
.
Generator
(
device
=
protein
[
"msa"
].
device
)
if
seed
is
not
None
:
g
.
manual_seed
(
seed
)
sample
=
torch
.
rand
(
sh
,
device
=
device
,
generator
=
g
)
mask_position
=
sample
<
replace_fraction
bert_msa
=
shaped_categorical
(
categorical_probs
)
bert_msa
=
shaped_categorical
(
categorical_probs
)
bert_msa
=
torch
.
where
(
mask_position
,
bert_msa
,
protein
[
"msa"
])
bert_msa
=
torch
.
where
(
mask_position
,
bert_msa
,
protein
[
"msa"
])
...
...
openfold/data/data_transforms_multimer.py
0 → 100644
View file @
4bd1b4d5
from
typing
import
Sequence
import
torch
from
openfold.data.data_transforms
import
curry1
from
openfold.utils.tensor_utils
import
masked_mean
def gumbel_noise(
    shape: Sequence[int],
    device: torch.device,
    eps=1e-6,
    generator=None,
) -> torch.Tensor:
    """Draw Gumbel(0, 1) noise of the requested shape.

    Args:
        shape: Shape of the noise tensor to return.
        device: Device on which to allocate the samples.
        eps: Small constant guarding the logs against log(0).
        generator: Optional torch.Generator for reproducible sampling.

    Returns:
        A float32 tensor of Gumbel(0, 1) samples with the given shape.
    """
    u = torch.rand(
        shape, dtype=torch.float32, device=device, generator=generator
    )
    # Inverse-CDF transform: g = -log(-log(U)), stabilized by eps.
    return -torch.log(-torch.log(u + eps) + eps)
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
    """Sample from a (possibly unnormalized) categorical distribution.

    Uses the Gumbel-max trick: adding Gumbel(0, 1) noise to the logits and
    taking the argmax is an efficient way to sample from softmax(logits).

    Args:
        logits: Logarithm of probabilities to sample from; probabilities can
            be unnormalized.
        generator: Optional torch.Generator for reproducible sampling.

    Returns:
        The sample in one-hot form, over the last dimension of logits.
    """
    perturbation = gumbel_noise(
        logits.shape, device=logits.device, generator=generator
    )
    sampled_index = torch.argmax(logits + perturbation, dim=-1)
    return torch.nn.functional.one_hot(sampled_index, logits.shape[-1])
def gumbel_argsort_sample_idx(logits: torch.Tensor, generator=None) -> torch.Tensor:
    """Sample a random permutation weighted by 'logits'.

    Uses the Gumbel trick to implement the sampling in an efficient manner.
    For a distribution over k items this samples k times without replacement,
    so this is effectively sampling a random permutation with probabilities
    over the permutations derived from the logprobs.

    Args:
        logits: Logarithm of probabilities to sample from; probabilities can
            be unnormalized.
        generator: Optional torch.Generator for reproducible sampling.

    Returns:
        Indices of the last dimension sorted by descending perturbed logit,
        i.e. a sampled permutation.
    """
    perturbation = gumbel_noise(
        logits.shape, device=logits.device, generator=generator
    )
    return torch.argsort(logits + perturbation, dim=-1, descending=True)
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
    """Create data for BERT-style masked-MSA training on the raw MSA.

    Args:
        batch: Feature dict with 'msa', 'msa_mask' and 'msa_profile'.
        config: Config with profile_prob, same_prob and uniform_prob.
        replace_fraction: Fraction of MSA positions to corrupt.
        seed: Optional seed making the corruption deterministic.
        eps: Small constant guarding the log against log(0).

    Returns:
        The batch, updated in place: 'bert_mask' marks corrupted positions,
        'true_msa' keeps the original MSA, and 'msa' is the corrupted MSA.
    """
    device = batch['msa'].device

    # Add a random amino acid uniformly (zero weight on the last 2 classes).
    # torch.tensor (not the legacy torch.Tensor constructor) is required for
    # the device kwarg to work on non-CPU devices.
    random_aa = torch.tensor(
        [0.05] * 20 + [0., 0.], dtype=torch.float32, device=device
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * batch['msa_profile']
        + config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
    )

    # Put all remaining probability on [MASK] which is a new column.
    mask_prob = (
        1. - config.profile_prob - config.same_prob - config.uniform_prob
    )
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, [0, 1], value=mask_prob
    )

    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)

    sh = batch['msa'].shape
    # Sample mask positions from the seeded generator too, so the choice of
    # corrupted positions (not just the replacement values) is reproducible —
    # consistent with the monomer make_masked_msa.
    mask_position = (
        torch.rand(sh, device=device, generator=g) < replace_fraction
    )
    mask_position *= batch['msa_mask'].to(mask_position.dtype)

    logits = torch.log(categorical_probs + eps)
    bert_msa = gumbel_max_sample(logits, generator=g)

    bert_msa = torch.where(
        mask_position,
        torch.argmax(bert_msa, dim=-1),
        batch['msa']
    )
    bert_msa *= batch['msa_mask'].to(bert_msa.dtype)

    # Mix real and masked MSA.
    if 'bert_mask' in batch:
        batch['bert_mask'] *= mask_position.to(torch.float32)
    else:
        batch['bert_mask'] = mask_position.to(torch.float32)
    batch['true_msa'] = batch['msa']
    batch['msa'] = bert_msa

    return batch
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
    """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.

    Computes a soft assignment of extra MSA rows onto sampled MSA rows and
    stores the resulting per-cluster amino-acid profile and mean deletion
    count in the batch.

    Args:
        batch: Feature dict with 'msa', 'msa_mask', 'deletion_matrix' and
            their 'extra_*' counterparts.
        gap_agreement_weight: Weight given to agreement on the gap class when
            scoring sequence similarity.

    Returns:
        The batch, updated in place with 'cluster_profile' and
        'cluster_deletion_mean'.
    """
    device = batch["msa_mask"].device

    # Determine how much weight we assign to each agreement. In theory, we
    # could use a full blosum matrix here, but right now let's just
    # down-weight gap agreement because it could be spurious.
    # Never put weight on agreeing on BERT mask.
    # torch.tensor (not the legacy torch.Tensor constructor) is required for
    # the device kwarg to work on non-CPU devices.
    weights = torch.tensor(
        [1.] * 21 + [gap_agreement_weight] + [0.],
        dtype=torch.float32,
        device=device,
    )

    msa_mask = batch['msa_mask']
    msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)

    extra_mask = batch['extra_msa_mask']
    extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)

    msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
    extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot

    # Agreement score between every extra row (m) and every sampled row (n).
    agreement = torch.einsum(
        'mrc, nrc->nm', extra_one_hot_masked, weights * msa_one_hot_masked
    )

    # A very sharp softmax over the sampled rows approximates a hard
    # nearest-neighbor assignment.
    cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
    cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)

    cluster_count = torch.sum(cluster_assignment, dim=-1)
    cluster_count += 1.  # We always include the sequence itself.

    msa_sum = torch.einsum(
        'nm, mrc->nrc', cluster_assignment, extra_one_hot_masked
    )
    msa_sum += msa_one_hot_masked
    cluster_profile = msa_sum / cluster_count[:, None, None]

    extra_deletion_matrix = batch['extra_deletion_matrix']
    deletion_matrix = batch['deletion_matrix']

    del_sum = torch.einsum(
        'nm, mc->nc',
        cluster_assignment,
        extra_mask * extra_deletion_matrix
    )
    del_sum += deletion_matrix  # Original sequence.
    cluster_deletion_mean = del_sum / cluster_count[:, None]

    batch['cluster_profile'] = cluster_profile
    batch['cluster_deletion_mean'] = cluster_deletion_mean

    return batch
def create_target_feat(batch):
    """Populate batch["target_feat"] with a 21-class one-hot of aatype."""
    aatype_one_hot = torch.nn.functional.one_hot(batch["aatype"], 21)
    batch["target_feat"] = aatype_one_hot.to(torch.float32)
    return batch
def create_msa_feat(batch):
    """Create and concatenate MSA features.

    Concatenates, along the last dimension: the 23-class one-hot MSA, a
    has-deletion flag, an arctan-squashed deletion count, the cluster
    profile, and the arctan-squashed cluster deletion mean. Stores the
    result under batch["msa_feat"].
    """
    # (Removed dead `device = batch["msa"]` — it bound a tensor, not a
    # device, and was never used.)
    msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
    deletion_matrix = batch['deletion_matrix']
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    # pi computed on the same device as the deletion matrix.
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    # Squash unbounded deletion counts into [0, 1) via arctan.
    deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]

    deletion_mean_value = (
        torch.atan(batch['cluster_deletion_mean'] / 3.) * (2. / pi)
    )[..., None]

    msa_feat = torch.cat(
        [
            msa_1hot,
            has_deletion,
            deletion_value,
            batch['cluster_profile'],
            deletion_mean_value,
        ],
        dim=-1,
    )

    batch["msa_feat"] = msa_feat
    return batch
def build_extra_msa_feat(batch):
    """Expand extra_msa into 1-hot and concat with other extra MSA features.

    We do this as late as possible as the one_hot extra msa can be very
    large.

    Args:
        batch: a dictionary with the following keys:
            * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a
              cluster centre. Note - This isn't one-hotted.
            * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions
              at given position.

    Returns:
        Concatenated tensor of extra MSA features, shape
        [num_seq, num_res, 25].
    """
    # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
    extra_msa = batch['extra_msa']
    deletion_matrix = batch['extra_deletion_matrix']
    msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
    has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
    pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
    # Squash unbounded deletion counts into [0, 1) via arctan.
    deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
    # (Removed the unused `extra_msa_mask` local — it was read from the batch
    # and never used.)
    return torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
    """Sample the MSA randomly; leftover sequences become `extra_*` features.

    Args:
        batch: Feature dict containing at least 'msa' and 'msa_mask'.
        max_seq: Number of sequences to keep in the main MSA.
        max_extra_msa_seq: Cap on the number of leftover ("extra") sequences.
        seed: Optional seed for reproducible sampling.
        inf: Large constant used to force or forbid selection via logits.

    Returns:
        The batch with 'msa', 'deletion_matrix', 'msa_mask' and 'bert_mask'
        subsampled, and their leftovers stored under 'extra_*' keys.
    """
    rng = torch.Generator(device=batch["msa"].device)
    if seed is not None:
        rng.manual_seed(seed)

    # Sample uniformly among sequences with at least one non-masked position:
    # fully-masked rows get logit -inf, all others logit 0.
    has_content = torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.)
    logits = (has_content - 1.) * inf

    # The cluster_bias_mask can be used to preserve the first row (target
    # sequence) for each chain, for example. Default: pin only row 0.
    if 'cluster_bias_mask' in batch:
        cluster_bias_mask = batch['cluster_bias_mask']
    else:
        cluster_bias_mask = torch.nn.functional.pad(
            batch['msa'].new_zeros(batch['msa'].shape[0] - 1), (1, 0), value=1.
        )

    logits += cluster_bias_mask * inf
    index_order = gumbel_argsort_sample_idx(logits, generator=rng)
    selected_rows = index_order[:max_seq]
    extra_rows = index_order[max_seq:][:max_extra_msa_seq]

    for key in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
        if key not in batch:
            continue
        batch['extra_' + key] = batch[key][extra_rows]
        batch[key] = batch[key][selected_rows]

    return batch
def make_msa_profile(batch):
    """Compute the per-residue MSA profile over all MSA sequences."""
    msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 22)
    mask = batch['msa_mask'][..., None]
    # Masked average over the sequence dimension.
    batch["msa_profile"] = masked_mean(mask, msa_one_hot, dim=-3)
    return batch
openfold/data/feature_pipeline.py
View file @
4bd1b4d5
...
@@ -20,7 +20,7 @@ import ml_collections
...
@@ -20,7 +20,7 @@ import ml_collections
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
openfold.data
import
input_pipeline
from
openfold.data
import
input_pipeline
,
input_pipeline_multimer
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
FeatureDict
=
Mapping
[
str
,
np
.
ndarray
]
...
@@ -73,8 +73,10 @@ def np_example_to_features(
...
@@ -73,8 +73,10 @@ def np_example_to_features(
np_example
:
FeatureDict
,
np_example
:
FeatureDict
,
config
:
ml_collections
.
ConfigDict
,
config
:
ml_collections
.
ConfigDict
,
mode
:
str
,
mode
:
str
,
is_multimer
:
bool
=
False
):
):
np_example
=
dict
(
np_example
)
np_example
=
dict
(
np_example
)
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
num_res
=
int
(
np_example
[
"seq_length"
][
0
])
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
cfg
,
feature_names
=
make_data_config
(
config
,
mode
=
mode
,
num_res
=
num_res
)
...
@@ -87,11 +89,18 @@ def np_example_to_features(
...
@@ -87,11 +89,18 @@ def np_example_to_features(
np_example
=
np_example
,
features
=
feature_names
np_example
=
np_example
,
features
=
feature_names
)
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
features
=
input_pipeline
.
process_tensors_from_config
(
if
(
not
is_multimer
):
tensor_dict
,
features
=
input_pipeline
.
process_tensors_from_config
(
cfg
.
common
,
tensor_dict
,
cfg
[
mode
],
cfg
.
common
,
)
cfg
[
mode
],
)
else
:
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
...
@@ -106,10 +115,15 @@ class FeaturePipeline:
...
@@ -106,10 +115,15 @@ class FeaturePipeline:
def
process_features
(
def
process_features
(
self
,
self
,
raw_features
:
FeatureDict
,
raw_features
:
FeatureDict
,
mode
:
str
=
"train"
,
mode
:
str
=
"train"
,
is_multimer
:
bool
=
False
,
)
->
FeatureDict
:
)
->
FeatureDict
:
if
(
is_multimer
and
mode
!=
"predict"
):
raise
ValueError
(
"Multimer mode is not currently trainable"
)
return
np_example_to_features
(
return
np_example_to_features
(
np_example
=
raw_features
,
np_example
=
raw_features
,
config
=
self
.
config
,
config
=
self
.
config
,
mode
=
mode
,
mode
=
mode
,
is_multimer
=
is_multimer
,
)
)
openfold/data/
multimer_
feature_processing.py
→
openfold/data/feature_processing
_multimer
.py
View file @
4bd1b4d5
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
"""Feature processing logic for multimer data pipeline."""
"""Feature processing logic for multimer data pipeline."""
from
typing
import
Iterable
,
MutableMapping
,
List
from
typing
import
Iterable
,
MutableMapping
,
List
,
Mapping
from
openfold.data
import
msa_pairing
from
openfold.data
import
msa_pairing
from
openfold.np
import
residue_constants
from
openfold.np
import
residue_constants
...
@@ -49,13 +49,11 @@ def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
...
@@ -49,13 +49,11 @@ def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
def
pair_and_merge
(
def
pair_and_merge
(
all_chain_features
:
MutableMapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]],
all_chain_features
:
MutableMapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]],
is_prokaryote
:
bool
)
->
Mapping
[
str
,
np
.
ndarray
]:
)
->
Mapping
[
str
,
np
.
ndarray
]:
"""Runs processing on features to augment, pair and merge.
"""Runs processing on features to augment, pair and merge.
Args:
Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
all_chain_features: A MutableMap of dictionaries of features for each chain.
is_prokaryote: Whether the target complex is from a prokaryotic or
eukaryotic organism.
Returns:
Returns:
A dictionary of features.
A dictionary of features.
...
@@ -69,7 +67,8 @@ def pair_and_merge(
...
@@ -69,7 +67,8 @@ def pair_and_merge(
if
pair_msa_sequences
:
if
pair_msa_sequences
:
np_chains_list
=
msa_pairing
.
create_paired_features
(
np_chains_list
=
msa_pairing
.
create_paired_features
(
chains
=
np_chains_list
,
prokaryotic
=
is_prokaryote
)
chains
=
np_chains_list
)
np_chains_list
=
msa_pairing
.
deduplicate_unpaired_sequences
(
np_chains_list
)
np_chains_list
=
msa_pairing
.
deduplicate_unpaired_sequences
(
np_chains_list
)
np_chains_list
=
crop_chains
(
np_chains_list
=
crop_chains
(
np_chains_list
,
np_chains_list
,
...
@@ -175,6 +174,7 @@ def process_final(
...
@@ -175,6 +174,7 @@ def process_final(
np_example
=
_make_seq_mask
(
np_example
)
np_example
=
_make_seq_mask
(
np_example
)
np_example
=
_make_msa_mask
(
np_example
)
np_example
=
_make_msa_mask
(
np_example
)
np_example
=
_filter_features
(
np_example
)
np_example
=
_filter_features
(
np_example
)
return
np_example
return
np_example
...
@@ -210,19 +210,23 @@ def _filter_features(
...
@@ -210,19 +210,23 @@ def _filter_features(
def
process_unmerged_features
(
def
process_unmerged_features
(
all_chain_features
:
MutableMapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]]):
all_chain_features
:
MutableMapping
[
str
,
Mapping
[
str
,
np
.
ndarray
]]
):
"""Postprocessing stage for per-chain features before merging."""
"""Postprocessing stage for per-chain features before merging."""
num_chains
=
len
(
all_chain_features
)
num_chains
=
len
(
all_chain_features
)
for
chain_features
in
all_chain_features
.
values
():
for
chain_features
in
all_chain_features
.
values
():
# Convert deletion matrices to float.
# Convert deletion matrices to float.
chain_features
[
'deletion_matrix'
]
=
np
.
asarray
(
chain_features
[
'deletion_matrix'
]
=
np
.
asarray
(
chain_features
.
pop
(
'deletion_matrix_int'
),
dtype
=
np
.
float32
)
chain_features
.
pop
(
'deletion_matrix_int'
),
dtype
=
np
.
float32
)
if
'deletion_matrix_int_all_seq'
in
chain_features
:
if
'deletion_matrix_int_all_seq'
in
chain_features
:
chain_features
[
'deletion_matrix_all_seq'
]
=
np
.
asarray
(
chain_features
[
'deletion_matrix_all_seq'
]
=
np
.
asarray
(
chain_features
.
pop
(
'deletion_matrix_int_all_seq'
),
dtype
=
np
.
float32
)
chain_features
.
pop
(
'deletion_matrix_int_all_seq'
),
dtype
=
np
.
float32
)
chain_features
[
'deletion_mean'
]
=
np
.
mean
(
chain_features
[
'deletion_mean'
]
=
np
.
mean
(
chain_features
[
'deletion_matrix'
],
axis
=
0
)
chain_features
[
'deletion_matrix'
],
axis
=
0
)
# Add all_atom_mask and dummy all_atom_positions based on aatype.
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask
=
residue_constants
.
STANDARD_ATOM_MASK
[
all_atom_mask
=
residue_constants
.
STANDARD_ATOM_MASK
[
...
...
openfold/data/input_pipeline_multimer.py
0 → 100644
View file @
4bd1b4d5
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
functools
import
partial
import
torch
from
openfold.data
import
(
data_transforms
,
data_transforms_multimer
,
)
def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms_multimer.make_msa_profile,
        data_transforms_multimer.create_target_feat,
    ]

    if common_cfg.use_templates:
        transforms.append(data_transforms.make_pseudo_beta("template_"))

    return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged.

    Args:
        common_cfg: Config shared across modes (max_extra_msa, masked_msa,
            resample_msa_in_recycling, ...).
        mode_cfg: Per-mode config (max_msa_clusters,
            masked_msa_replace_fraction, ...).
        ensemble_seed: Seed shared across ensemble members so that, when MSA
            resampling in recycling is disabled, every iteration sees the
            same sample.

    Returns:
        A list of transforms to apply to each ensemble member.
    """
    transforms = []

    pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    # Only fix the MSA sampling seed when all recycling iterations must see
    # the same MSA sample.
    msa_seed = None
    if not common_cfg.resample_msa_in_recycling:
        msa_seed = ensemble_seed

    transforms.append(
        data_transforms_multimer.sample_msa(
            max_msa_clusters,
            max_extra_msa,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms_multimer.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                # `is not None`, not truthiness: a seed of 0 is valid and
                # must still yield deterministic masking.
                seed=(msa_seed + 1) if msa_seed is not None else None,
            )
        )

    transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
    transforms.append(data_transforms_multimer.create_msa_feat)

    return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data.

    Applies the non-ensembled transforms once, then maps the ensembled
    transforms over one ensemble member per recycling iteration.

    Args:
        tensors: Raw feature dict of tensors.
        common_cfg: Config shared across modes.
        mode_cfg: Per-mode config.

    Returns:
        Feature dict with an ensemble dimension stacked along dim -1.
    """
    # One random seed shared by every ensemble member, so transforms can opt
    # into sampling identically across recycling iterations.
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # (Removed the unused `no_templates` computation — its value was never
    # read.)
    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    # One ensemble member per recycling iteration (plus the initial pass).
    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors
@data_transforms.curry1
def compose(x, fs):
    """Apply each transform in fs to x, in order, and return the result."""
    result = x
    for transform in fs:
        result = transform(result)
    return result
def map_fn(fun, x):
    """Apply fun to each element of x; stack results per key along dim -1."""
    outputs = [fun(elem) for elem in x]
    return {
        key: torch.stack([out[key] for out in outputs], dim=-1)
        for key in outputs[0].keys()
    }
openfold/data/msa_identifiers.py
View file @
4bd1b4d5
...
@@ -48,7 +48,6 @@ _UNIPROT_PATTERN = re.compile(
...
@@ -48,7 +48,6 @@ _UNIPROT_PATTERN = re.compile(
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
Identifiers
:
class
Identifiers
:
uniprot_accession_id
:
str
=
''
species_id
:
str
=
''
species_id
:
str
=
''
...
@@ -69,8 +68,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
...
@@ -69,8 +68,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
matches
=
re
.
search
(
_UNIPROT_PATTERN
,
msa_sequence_identifier
.
strip
())
matches
=
re
.
search
(
_UNIPROT_PATTERN
,
msa_sequence_identifier
.
strip
())
if
matches
:
if
matches
:
return
Identifiers
(
return
Identifiers
(
uniprot_accession
_id
=
matches
.
group
(
'
Accession
Identifier'
)
,
species
_id
=
matches
.
group
(
'
Species
Identifier'
)
species_id
=
matches
.
group
(
'SpeciesIdentifier'
)
)
)
return
Identifiers
()
return
Identifiers
()
...
...
openfold/data/msa_pairing.py
View file @
4bd1b4d5
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
collections
import
collections
import
functools
import
functools
import
string
import
string
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Sequence
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Sequence
,
Mapping
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
...
@@ -27,12 +27,6 @@ from openfold.np import residue_constants
...
@@ -27,12 +27,6 @@ from openfold.np import residue_constants
# TODO: This stuff should probably also be in a config
# TODO: This stuff should probably also be in a config
ALPHA_ACCESSION_ID_MAP
=
{
x
:
y
for
y
,
x
in
enumerate
(
string
.
ascii_uppercase
)}
ALPHANUM_ACCESSION_ID_MAP
=
{
chr
:
num
for
num
,
chr
in
enumerate
(
string
.
ascii_uppercase
+
string
.
digits
)
}
# A-Z,0-9
NUM_ACCESSION_ID_MAP
=
{
str
(
x
):
x
for
x
in
range
(
10
)}
# 0-9
MSA_GAP_IDX
=
residue_constants
.
restypes_with_x_and_gap
.
index
(
'-'
)
MSA_GAP_IDX
=
residue_constants
.
restypes_with_x_and_gap
.
index
(
'-'
)
SEQUENCE_GAP_CUTOFF
=
0.5
SEQUENCE_GAP_CUTOFF
=
0.5
SEQUENCE_SIMILARITY_CUTOFF
=
0.9
SEQUENCE_SIMILARITY_CUTOFF
=
0.9
...
@@ -61,14 +55,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')
...
@@ -61,14 +55,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')
def
create_paired_features
(
def
create_paired_features
(
chains
:
Iterable
[
Mapping
[
str
,
np
.
ndarray
]],
chains
:
Iterable
[
Mapping
[
str
,
np
.
ndarray
]],
prokaryotic
:
bool
,
)
->
List
[
Mapping
[
str
,
np
.
ndarray
]]:
)
->
List
[
Mapping
[
str
,
np
.
ndarray
]]:
"""Returns the original chains with paired NUM_SEQ features.
"""Returns the original chains with paired NUM_SEQ features.
Args:
Args:
chains: A list of feature dictionaries for each chain.
chains: A list of feature dictionaries for each chain.
prokaryotic: Whether the target complex is from a prokaryotic organism.
Used to determine the distance metric for pairing.
Returns:
Returns:
A list of feature dictionaries with sequence features including only
A list of feature dictionaries with sequence features including only
...
@@ -81,8 +72,7 @@ def create_paired_features(
...
@@ -81,8 +72,7 @@ def create_paired_features(
return
chains
return
chains
else
:
else
:
updated_chains
=
[]
updated_chains
=
[]
paired_chains_to_paired_row_indices
=
pair_sequences
(
paired_chains_to_paired_row_indices
=
pair_sequences
(
chains
)
chains
,
prokaryotic
)
paired_rows
=
reorder_paired_rows
(
paired_rows
=
reorder_paired_rows
(
paired_chains_to_paired_row_indices
)
paired_chains_to_paired_row_indices
)
...
@@ -117,8 +107,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
...
@@ -117,8 +107,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
num_res
=
feature
.
shape
[
1
]
num_res
=
feature
.
shape
[
1
]
padding
=
MSA_PAD_VALUES
[
feature_name
]
*
np
.
ones
([
1
,
num_res
],
padding
=
MSA_PAD_VALUES
[
feature_name
]
*
np
.
ones
([
1
,
num_res
],
feature
.
dtype
)
feature
.
dtype
)
elif
feature_name
in
(
'msa_uniprot_accession_identifiers_all_seq'
,
elif
feature_name
==
'msa_species_identifiers_all_seq'
:
'msa_species_identifiers_all_seq'
):
padding
=
[
b
''
]
padding
=
[
b
''
]
else
:
else
:
return
feature
return
feature
...
@@ -136,11 +125,9 @@ def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
...
@@ -136,11 +125,9 @@ def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
msa_df
=
pd
.
DataFrame
({
msa_df
=
pd
.
DataFrame
({
'msa_species_identifiers'
:
'msa_species_identifiers'
:
chain_features
[
'msa_species_identifiers_all_seq'
],
chain_features
[
'msa_species_identifiers_all_seq'
],
'msa_uniprot_accession_identifiers'
:
chain_features
[
'msa_uniprot_accession_identifiers_all_seq'
],
'msa_row'
:
'msa_row'
:
np
.
arange
(
len
(
np
.
arange
(
len
(
chain_features
[
'msa_
uniprot_accession
_identifiers_all_seq'
])),
chain_features
[
'msa_
species
_identifiers_all_seq'
])),
'msa_similarity'
:
per_seq_similarity
,
'msa_similarity'
:
per_seq_similarity
,
'gap'
:
per_seq_gap
'gap'
:
per_seq_gap
})
})
...
@@ -155,139 +142,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
...
@@ -155,139 +142,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
return
species_lookup
return
species_lookup
@functools.lru_cache(maxsize=65536)
def encode_accession(accession_id: str) -> int:
    """Map accession codes to the serial order in which they were assigned.

    Treats each character of the id as a digit of a mixed-radix number whose
    per-position alphabets follow the UniProt accession format.
    """
    alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
    alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
    num = NUM_ACCESSION_ID_MAP  # 0-9

    coding = 0

    # This is based on the uniprot accession id format
    # https://www.uniprot.org/help/accession_numbers
    # NOTE(review): `bases` is only assigned in the three branches below; an
    # id matching none of them (unexpected length, not starting with O/P/Q)
    # would raise UnboundLocalError — confirm inputs are pre-validated.
    if accession_id[0] in {'O', 'P', 'Q'}:
        bases = (alpha, num, alphanum, alphanum, alphanum, num)
    elif len(accession_id) == 6:
        bases = (alpha, num, alpha, alphanum, alphanum, num)
    elif len(accession_id) == 10:
        bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum, alphanum, num)

    product = 1
    # Accumulate little-endian: least-significant (rightmost) position first.
    for place, base in zip(reversed(accession_id), reversed(bases)):
        coding += base[place] * product
        product *= len(base)

    return coding
def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
    """Absolute difference between two encoded UniProt accession ids."""
    code_a = encode_accession(id_a.decode())
    code_b = encode_accession(id_b.decode())
    return abs(code_a - code_b)
def _find_all_accession_matches(
    accession_id_lists: List[List[bytes]],
    diff_cutoff: int = 20
) -> List[List[Any]]:
    """Finds accession id matches across the chains based on their difference.

    Performs a depth-first search over one accession id per chain, keeping
    only tuples whose pairwise encoded-id differences are all below
    diff_cutoff and whose ids have not been used in an earlier match.
    """
    all_accession_tuples = []
    current_tuple = []  # DFS path: one candidate id per visited chain.
    tokens_used_in_answer = set()  # Ids already consumed by accepted tuples.

    def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
        # Candidate must be "close" to every id already on the DFS path.
        return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))

    def _all_tokens_not_used_before() -> bool:
        return all((s not in tokens_used_in_answer for s in current_tuple))

    def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
        # `level` is the index of the chain whose id was chosen last;
        # level == -1 means nothing chosen yet.
        if level == len(accession_id_lists) - 1:
            # A full tuple (one id per chain) has been assembled.
            if _all_tokens_not_used_before():
                all_accession_tuples.append(list(current_tuple))
                for s in current_tuple:
                    tokens_used_in_answer.add(s)
            return

        if level == -1:
            new_list = accession_id_lists[level + 1]
        else:
            # Visit the next chain's candidates nearest-first, so the
            # greedy consumption of ids prefers close matches.
            new_list = [(_calc_id_diff(accession_id, s), s) for s in accession_id_lists[level + 1]]
            new_list = sorted(new_list)
            new_list = [s for d, s in new_list]

        for s in new_list:
            if (_matches_all_in_current_tuple(s, diff_cutoff) and s not in tokens_used_in_answer):
                current_tuple.append(s)
                dfs(level + 1, s)
                current_tuple.pop()

    dfs(-1, '')
    return all_accession_tuples
def
_accession_row
(
msa_df
:
pd
.
DataFrame
,
accession_id
:
bytes
)
->
pd
.
Series
:
matched_df
=
msa_df
[
msa_df
.
msa_uniprot_accession_identifiers
==
accession_id
]
return
matched_df
.
iloc
[
0
]
def _match_rows_by_genetic_distance(
    this_species_msa_dfs: List[pd.DataFrame],
    cutoff: int = 20
) -> List[List[int]]:
    """Finds MSA sequence pairings across chains within a genetic distance cutoff.

    The genetic distance between two sequences is approximated by taking the
    difference in their UniProt accession ids.

    Args:
        this_species_msa_dfs: a list of dataframes containing MSA features for
            sequences for a specific species. If species is missing for a chain,
            the dataframe is set to None.
        cutoff: the genetic distance cutoff.

    Returns:
        A list of lists, each containing M indices corresponding to paired MSA
        rows, where M is the number of chains (-1 for chains with no match).
    """
    num_examples = len(this_species_msa_dfs)  # N

    accession_id_lists = []  # M
    match_index_to_chain_index = {}
    for chain_index, species_df in enumerate(this_species_msa_dfs):
        if species_df is not None:
            accession_id_lists.append(
                list(species_df.msa_uniprot_accession_identifiers.values)
            )
            # Keep track of which of the this_species_msa_dfs are not None.
            match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index

    all_accession_id_matches = _find_all_accession_matches(
        accession_id_lists, cutoff
    )  # [k, M]

    all_paired_msa_rows = []  # [k, N]
    for accession_id_match in all_accession_id_matches:
        paired_msa_rows = []
        for match_index, accession_id in enumerate(accession_id_match):
            # Map back to chain index.
            chain_index = match_index_to_chain_index[match_index]
            seq_series = _accession_row(
                this_species_msa_dfs[chain_index], accession_id
            )

            # Drop sequences too similar to the target or mostly gaps.
            if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
                    seq_series.gap > SEQUENCE_GAP_CUTOFF):
                continue
            else:
                paired_msa_rows.append(seq_series.msa_row)
        # If a sequence is skipped based on sequence similarity to the
        # respective target sequence or a gap cutoff, the lengths of
        # accession_id_match and paired_msa_rows will be different. Skip this
        # match.
        if len(paired_msa_rows) == len(accession_id_match):
            # -1 marks chains that contributed no row for this species.
            paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
            matched_chain_indices = list(match_index_to_chain_index.values())
            paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
            all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))

    return all_paired_msa_rows
def
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
:
List
[
pd
.
DataFrame
]
def
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
:
List
[
pd
.
DataFrame
]
)
->
List
[
List
[
int
]]:
)
->
List
[
List
[
int
]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
"""Finds MSA sequence pairings across chains based on sequence similarity.
...
@@ -324,8 +178,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
...
@@ -324,8 +178,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
return
all_paired_msa_rows
return
all_paired_msa_rows
def
pair_sequences
(
examples
:
List
[
Mapping
[
str
,
np
.
ndarray
]],
def
pair_sequences
(
prokaryotic
:
bool
)
->
Dict
[
int
,
np
.
ndarray
]:
examples
:
List
[
Mapping
[
str
,
np
.
ndarray
]],
)
->
Dict
[
int
,
np
.
ndarray
]:
"""Returns indices for paired MSA sequences across chains."""
"""Returns indices for paired MSA sequences across chains."""
num_examples
=
len
(
examples
)
num_examples
=
len
(
examples
)
...
@@ -367,23 +222,7 @@ def pair_sequences(examples: List[Mapping[str, np.ndarray]],
...
@@ -367,23 +222,7 @@ def pair_sequences(examples: List[Mapping[str, np.ndarray]],
isinstance
(
species_df
,
pd
.
DataFrame
)])
>
600
):
isinstance
(
species_df
,
pd
.
DataFrame
)])
>
600
):
continue
continue
# In prokaryotes (and some eukaryotes), interacting genes are often
paired_msa_rows
=
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
)
# co-located on the chromosome into operons. Because of that we can assume
# that if two proteins' intergenic distance is less than a threshold, they
# two proteins will form an interacting pair.
# In most eukaryotes, a single protein's MSA can contain many paralogs.
# Two genes may interact even if they are not close by genomic distance.
# In case of eukaryotes, some methods pair MSA sequences using sequence
# similarity method.
# See Jinbo Xu's work:
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
if
prokaryotic
:
paired_msa_rows
=
_match_rows_by_genetic_distance
(
this_species_msa_dfs
)
if
not
paired_msa_rows
:
continue
else
:
paired_msa_rows
=
_match_rows_by_sequence_similarity
(
this_species_msa_dfs
)
all_paired_msa_rows
.
extend
(
paired_msa_rows
)
all_paired_msa_rows
.
extend
(
paired_msa_rows
)
all_paired_msa_rows_dict
[
species_dfs_present
].
extend
(
paired_msa_rows
)
all_paired_msa_rows_dict
[
species_dfs_present
].
extend
(
paired_msa_rows
)
all_paired_msa_rows_dict
=
{
all_paired_msa_rows_dict
=
{
...
@@ -431,48 +270,66 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
...
@@ -431,48 +270,66 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
def
_correct_post_merged_feats
(
def
_correct_post_merged_feats
(
np_example
:
Mapping
[
str
,
np
.
ndarray
],
np_example
:
Mapping
[
str
,
np
.
ndarray
],
np_chains_list
:
Sequence
[
Mapping
[
str
,
np
.
ndarray
]],
np_chains_list
:
Sequence
[
Mapping
[
str
,
np
.
ndarray
]],
pair_msa_sequences
:
bool
)
->
Mapping
[
str
,
np
.
ndarray
]:
pair_msa_sequences
:
bool
"""Adds features that need to be computed/recomputed post merging."""
)
->
Mapping
[
str
,
np
.
ndarray
]:
"""Adds features that need to be computed/recomputed post merging."""
np_example
[
'seq_length'
]
=
np
.
asarray
(
np_example
[
'aatype'
].
shape
[
0
],
dtype
=
np
.
int32
)
num_res
=
np_example
[
'aatype'
].
shape
[
0
]
np_example
[
'num_alignments'
]
=
np
.
asarray
(
np_example
[
'msa'
].
shape
[
0
],
np_example
[
'seq_length'
]
=
np
.
asarray
(
dtype
=
np
.
int32
)
[
num_res
]
*
num_res
,
dtype
=
np
.
int32
if
not
pair_msa_sequences
:
)
# Generate a bias that is 1 for the first row of every block in the
np_example
[
'num_alignments'
]
=
np
.
asarray
(
# block diagonal MSA - i.e. make sure the cluster stack always includes
np_example
[
'msa'
].
shape
[
0
],
# the query sequences for each chain (since the first row is the query
dtype
=
np
.
int32
# sequence).
)
cluster_bias_masks
=
[]
for
chain
in
np_chains_list
:
if
not
pair_msa_sequences
:
mask
=
np
.
zeros
(
chain
[
'msa'
].
shape
[
0
])
# Generate a bias that is 1 for the first row of every block in the
mask
[
0
]
=
1
# block diagonal MSA - i.e. make sure the cluster stack always includes
cluster_bias_masks
.
append
(
mask
)
# the query sequences for each chain (since the first row is the query
np_example
[
'cluster_bias_mask'
]
=
np
.
concatenate
(
cluster_bias_masks
)
# sequence).
cluster_bias_masks
=
[]
# Initialize Bert mask with masked out off diagonals.
for
chain
in
np_chains_list
:
msa_masks
=
[
np
.
ones
(
x
[
'msa'
].
shape
,
dtype
=
np
.
float32
)
mask
=
np
.
zeros
(
chain
[
'msa'
].
shape
[
0
])
for
x
in
np_chains_list
]
mask
[
0
]
=
1
cluster_bias_masks
.
append
(
mask
)
np_example
[
'bert_mask'
]
=
block_diag
(
*
msa_masks
,
pad_value
=
0
)
np_example
[
'cluster_bias_mask'
]
=
np
.
concatenate
(
cluster_bias_masks
)
else
:
np_example
[
'cluster_bias_mask'
]
=
np
.
zeros
(
np_example
[
'msa'
].
shape
[
0
])
# Initialize Bert mask with masked out off diagonals.
np_example
[
'cluster_bias_mask'
][
0
]
=
1
msa_masks
=
[
np
.
ones
(
x
[
'msa'
].
shape
,
dtype
=
np
.
float32
)
# Initialize Bert mask with masked out off diagonals.
for
x
in
np_chains_list
msa_masks
=
[
np
.
ones
(
x
[
'msa'
].
shape
,
dtype
=
np
.
float32
)
for
]
x
in
np_chains_list
]
msa_masks_all_seq
=
[
np
.
ones
(
x
[
'msa_all_seq'
].
shape
,
dtype
=
np
.
float32
)
for
np_example
[
'bert_mask'
]
=
block_diag
(
x
in
np_chains_list
]
*
msa_masks
,
pad_value
=
0
)
msa_mask_block_diag
=
block_diag
(
else
:
*
msa_masks
,
pad_value
=
0
)
np_example
[
'cluster_bias_mask'
]
=
np
.
zeros
(
np_example
[
'msa'
].
shape
[
0
])
msa_mask_all_seq
=
np
.
concatenate
(
msa_masks_all_seq
,
axis
=
1
)
np_example
[
'cluster_bias_mask'
][
0
]
=
1
np_example
[
'bert_mask'
]
=
np
.
concatenate
(
[
msa_mask_all_seq
,
msa_mask_block_diag
],
axis
=
0
)
# Initialize Bert mask with masked out off diagonals.
return
np_example
msa_masks
=
[
np
.
ones
(
x
[
'msa'
].
shape
,
dtype
=
np
.
float32
)
for
x
in
np_chains_list
]
msa_masks_all_seq
=
[
np
.
ones
(
x
[
'msa_all_seq'
].
shape
,
dtype
=
np
.
float32
)
for
x
in
np_chains_list
]
msa_mask_block_diag
=
block_diag
(
*
msa_masks
,
pad_value
=
0
)
msa_mask_all_seq
=
np
.
concatenate
(
msa_masks_all_seq
,
axis
=
1
)
np_example
[
'bert_mask'
]
=
np
.
concatenate
(
[
msa_mask_all_seq
,
msa_mask_block_diag
],
axis
=
0
)
return
np_example
def
_pad_templates
(
chains
:
Sequence
[
Mapping
[
str
,
np
.
ndarray
]],
def
_pad_templates
(
chains
:
Sequence
[
Mapping
[
str
,
np
.
ndarray
]],
...
...
openfold/data/parsers.py
View file @
4bd1b4d5
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
"""Functions for parsing various file formats."""
"""Functions for parsing various file formats."""
import
collections
import
collections
import
dataclasses
import
dataclasses
import
itertools
import
re
import
re
import
string
import
string
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Set
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Set
...
@@ -29,8 +30,7 @@ class Msa:
...
@@ -29,8 +30,7 @@ class Msa:
"""Class representing a parsed MSA file"""
"""Class representing a parsed MSA file"""
sequences
:
Sequence
[
str
]
sequences
:
Sequence
[
str
]
deletion_matrix
:
DeletionMatrix
deletion_matrix
:
DeletionMatrix
descriptions
:
Sequence
[
str
]
descriptions
:
Optional
[
Sequence
[
str
]]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
(
not
(
if
(
not
(
...
@@ -173,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Msa:
...
@@ -173,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Msa:
at `deletion_matrix[i][j]` is the number of residues deleted from
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
the aligned sequence i at residue position j.
"""
"""
sequences
,
descriptions
=
parse_fasta
(
a3m_string
)
sequences
,
descriptions
=
parse_fasta
(
a3m_string
)
deletion_matrix
=
[]
deletion_matrix
=
[]
for
msa_sequence
in
sequences
:
for
msa_sequence
in
sequences
:
deletion_vec
=
[]
deletion_vec
=
[]
...
@@ -642,3 +642,20 @@ def parse_hmmsearch_a3m(
...
@@ -642,3 +642,20 @@ def parse_hmmsearch_a3m(
hits
.
append
(
hit
)
hits
.
append
(
hit
)
return
hits
return
hits
def parse_hmmsearch_sto(
    output_string: str, input_sequence: str
) -> Sequence[TemplateHit]:
    """Parses template hits out of raw hmmsearch Stockholm output.

    The Stockholm alignment is first converted to A3M while keeping the
    query row's gap columns (remove_first_row_gaps=False), then handed to
    the hmmsearch A3M parser without skipping the first sequence.
    """
    a3m = convert_stockholm_to_a3m(output_string, remove_first_row_gaps=False)
    return parse_hmmsearch_a3m(
        query_sequence=input_sequence,
        a3m_string=a3m,
        skip_first=False,
    )
openfold/data/templates.py
View file @
4bd1b4d5
...
@@ -220,13 +220,6 @@ def _assess_hhsearch_hit(
...
@@ -220,13 +220,6 @@ def _assess_hhsearch_hit(
template_sequence
=
hit
.
hit_sequence
.
replace
(
"-"
,
""
)
template_sequence
=
hit
.
hit_sequence
.
replace
(
"-"
,
""
)
length_ratio
=
float
(
len
(
template_sequence
))
/
len
(
query_sequence
)
length_ratio
=
float
(
len
(
template_sequence
))
/
len
(
query_sequence
)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate
=
(
template_sequence
in
query_sequence
and
length_ratio
>
max_subsequence_ratio
)
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
if
_is_after_cutoff
(
hit_pdb_code
,
release_dates
,
release_date_cutoff
):
date
=
release_dates
[
hit_pdb_code
.
upper
()]
date
=
release_dates
[
hit_pdb_code
.
upper
()]
raise
DateError
(
raise
DateError
(
...
@@ -240,6 +233,13 @@ def _assess_hhsearch_hit(
...
@@ -240,6 +233,13 @@ def _assess_hhsearch_hit(
f
"Align ratio:
{
align_ratio
}
."
f
"Align ratio:
{
align_ratio
}
."
)
)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate
=
(
template_sequence
in
query_sequence
and
length_ratio
>
max_subsequence_ratio
)
if
duplicate
:
if
duplicate
:
raise
DuplicateError
(
raise
DuplicateError
(
"Template is an exact subsequence of query with large "
"Template is an exact subsequence of query with large "
...
@@ -770,7 +770,7 @@ def _prefilter_hit(
...
@@ -770,7 +770,7 @@ def _prefilter_hit(
except
PrefilterError
as
e
:
except
PrefilterError
as
e
:
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
hit_name
=
f
"
{
hit_pdb_code
}
_
{
hit_chain_id
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
msg
=
f
"hit
{
hit_name
}
did not pass prefilter:
{
str
(
e
)
}
"
logging
.
info
(
"%s: %s"
,
query_pdb_code
,
msg
)
logging
.
info
(
msg
)
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
DuplicateError
)):
if
strict_error_check
and
isinstance
(
e
,
(
DateError
,
DuplicateError
)):
# In strict mode we treat some prefilter cases as errors.
# In strict mode we treat some prefilter cases as errors.
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
return
PrefilterResult
(
valid
=
False
,
error
=
msg
,
warning
=
None
)
...
@@ -826,6 +826,7 @@ def _process_single_hit(
...
@@ -826,6 +826,7 @@ def _process_single_hit(
query_sequence
,
query_sequence
,
template_sequence
,
template_sequence
,
)
)
# Fail if we can't find the mmCIF file.
# Fail if we can't find the mmCIF file.
cif_string
=
_read_file
(
cif_path
)
cif_string
=
_read_file
(
cif_path
)
...
@@ -968,7 +969,7 @@ class TemplateHitFeaturizer(abc.ABC):
...
@@ -968,7 +969,7 @@ class TemplateHitFeaturizer(abc.ABC):
raise
ValueError
(
raise
ValueError
(
"max_template_date must be set and have format YYYY-MM-DD."
"max_template_date must be set and have format YYYY-MM-DD."
)
)
self
.
max_hits
=
max_hits
self
.
_
max_hits
=
max_hits
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_kalign_binary_path
=
kalign_binary_path
self
.
_strict_error_check
=
strict_error_check
self
.
_strict_error_check
=
strict_error_check
...
@@ -997,33 +998,23 @@ class TemplateHitFeaturizer(abc.ABC):
...
@@ -997,33 +998,23 @@ class TemplateHitFeaturizer(abc.ABC):
query_sequence
:
str
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
)
->
TemplateSearchResult
:
""" Computes the templates for a given query sequence """
class
HhsearchHitFeaturizer
(
TemplateHitFeaturizer
):
class
HhsearchHitFeaturizer
(
TemplateHitFeaturizer
):
def
get_templates
(
def
get_templates
(
self
,
self
,
query_sequence
:
str
,
query_sequence
:
str
,
query_release_date
:
Optional
[
datetime
.
datetime
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
hits
:
Sequence
[
parsers
.
TemplateHit
],
)
->
TemplateSearchResult
:
)
->
TemplateSearchResult
:
"""Computes the templates for given query sequence (more details above)."""
"""Computes the templates for given query sequence (more details above)."""
logging
.
info
(
"Searching for template for: %s"
,
query_
pdb_cod
e
)
logging
.
info
(
"Searching for template for: %s"
,
query_
sequenc
e
)
template_features
=
{}
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
template_features
[
template_feature_name
]
=
[]
# Always use a max_template_date. Set to query_release_date minus 60 days
already_seen
=
set
()
# if that's earlier.
template_cutoff_date
=
self
.
_max_template_date
if
query_release_date
:
delta
=
datetime
.
timedelta
(
days
=
60
)
if
query_release_date
-
delta
<
template_cutoff_date
:
template_cutoff_date
=
query_release_date
-
delta
assert
template_cutoff_date
<
query_release_date
assert
template_cutoff_date
<=
self
.
_max_template_date
num_hits
=
0
errors
=
[]
errors
=
[]
warnings
=
[]
warnings
=
[]
...
@@ -1032,7 +1023,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
...
@@ -1032,7 +1023,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
prefilter_result
=
_prefilter_hit
(
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
hit
=
hit
,
hit
=
hit
,
max_template_date
=
template_
cutoff_
date
,
max_template_date
=
self
.
_max_
template_date
,
release_dates
=
self
.
_release_dates
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
strict_error_check
=
self
.
_strict_error_check
,
...
@@ -1057,17 +1048,16 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
...
@@ -1057,17 +1048,16 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for
i
in
idx
:
for
i
in
idx
:
# We got all the templates we wanted, stop processing hits.
# We got all the templates we wanted, stop processing hits.
if
num_hits
>=
self
.
max_hits
:
if
len
(
already_seen
)
>=
self
.
max_hits
:
break
break
hit
=
filtered
[
i
]
hit
=
filtered
[
i
]
result
=
_process_single_hit
(
result
=
_process_single_hit
(
query_sequence
=
query_sequence
,
query_sequence
=
query_sequence
,
query_pdb_code
=
query_pdb_code
,
hit
=
hit
,
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
template_
cutoff_
date
,
max_template_date
=
self
.
_max_
template_date
,
release_dates
=
self
.
_release_dates
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
strict_error_check
=
self
.
_strict_error_check
,
...
@@ -1091,8 +1081,10 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
...
@@ -1091,8 +1081,10 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
result
.
warning
,
result
.
warning
,
)
)
else
:
else
:
# Increment the hit counter, since we got features out of this hit.
already_seen_key
=
result
.
features
[
"template_sequence"
]
num_hits
+=
1
if
(
already_seen_key
in
already_seen
):
continue
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
template_features
[
k
].
append
(
result
.
features
[
k
])
...
@@ -1118,6 +1110,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
...
@@ -1118,6 +1110,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
query_sequence
:
str
,
query_sequence
:
str
,
hits
:
Sequence
[
parsers
.
TemplateHit
]
hits
:
Sequence
[
parsers
.
TemplateHit
]
)
->
TemplateSearchResult
:
)
->
TemplateSearchResult
:
logging
.
info
(
"Searching for template for: %s"
,
query_sequence
)
template_features
=
{}
template_features
=
{}
for
template_feature_name
in
TEMPLATE_FEATURES
:
for
template_feature_name
in
TEMPLATE_FEATURES
:
template_features
[
template_feature_name
]
=
[]
template_features
[
template_feature_name
]
=
[]
...
@@ -1126,45 +1120,73 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
...
@@ -1126,45 +1120,73 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
errors
=
[]
errors
=
[]
warnings
=
[]
warnings
=
[]
if
not
hits
or
hits
[
0
].
sum_probs
is
None
:
# DISCREPANCY: This filtering scheme that saves time
sorted_hits
=
hits
filtered
=
[]
else
:
for
hit
in
hits
:
sorted_hits
=
sorted
(
hits
,
key
=
lambda
x
:
x
.
sum_probs
,
reverse
=
True
)
prefilter_result
=
_prefilter_hit
(
query_sequence
=
query_sequence
,
hit
=
hit
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
)
for
hit
in
sorted_hits
:
if
prefilter_result
.
error
:
if
(
len
(
already_seen
)
>=
self
.
_max_hits
):
errors
.
append
(
prefilter_result
.
error
)
break
result
=
_process_single_hit
(
if
prefilter_result
.
warning
:
query_sequence
=
query_sequence
,
warnings
.
append
(
prefilter_result
.
warning
)
hit
=
hit
,
mmcif_dir
=
self
.
_mmcif_dir
,
if
prefilter_result
.
valid
:
max_template_date
=
self
.
_max_template_date
,
filtered
.
append
(
hit
)
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
filtered
=
list
(
strict_error_check
=
self
.
_strict_error_check
,
sorted
(
kalign_binary_path
=
self
.
_kalign_binary_path
filtered
,
key
=
lambda
x
:
x
.
sum_probs
if
x
.
sum_probs
else
0.
,
reverse
=
True
)
)
)
idx
=
list
(
range
(
len
(
filtered
)))
if
(
self
.
_shuffle_top_k_prefiltered
):
stk
=
self
.
_shuffle_top_k_prefiltered
idx
[:
stk
]
=
np
.
random
.
permutation
(
idx
[:
stk
])
if
result
.
error
:
for
i
in
idx
:
errors
.
append
(
result
.
error
)
if
(
len
(
already_seen
)
>=
self
.
_max_hits
):
break
if
result
.
warning
:
hit
=
filtered
[
i
]
warnings
.
append
(
result
.
warning
)
if
result
.
features
is
None
:
result
=
_process_single_hit
(
logging
.
debug
(
query_sequence
=
query_sequence
,
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
=
hit
,
hit
.
name
,
result
.
error
,
result
.
warning
,
mmcif_dir
=
self
.
_mmcif_dir
,
max_template_date
=
self
.
_max_template_date
,
release_dates
=
self
.
_release_dates
,
obsolete_pdbs
=
self
.
_obsolete_pdbs
,
strict_error_check
=
self
.
_strict_error_check
,
kalign_binary_path
=
self
.
_kalign_binary_path
)
)
else
:
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
result
.
error
:
if
(
already_seen_key
in
already_seen
):
errors
.
append
(
result
.
error
)
continue
# Increment the hit counter, since we got features out of this hit.
if
result
.
warning
:
already_seen
.
add
(
already_seen_key
)
warnings
.
append
(
result
.
warning
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
if
result
.
features
is
None
:
logging
.
debug
(
"Skipped invalid hit %s, error: %s, warning: %s"
,
hit
.
name
,
result
.
error
,
result
.
warning
,
)
else
:
already_seen_key
=
result
.
features
[
"template_sequence"
]
if
(
already_seen_key
in
already_seen
):
continue
# Increment the hit counter, since we got features out of this hit.
already_seen
.
add
(
already_seen_key
)
for
k
in
template_features
:
template_features
[
k
].
append
(
result
.
features
[
k
])
if
already_seen
:
if
already_seen
:
for
name
in
template_features
:
for
name
in
template_features
:
...
...
openfold/data/tools/hhsearch.py
View file @
4bd1b4d5
...
@@ -18,7 +18,7 @@ import glob
...
@@ -18,7 +18,7 @@ import glob
import
logging
import
logging
import
os
import
os
import
subprocess
import
subprocess
from
typing
import
Sequence
from
typing
import
Sequence
,
Optional
from
openfold.data
import
parsers
from
openfold.data
import
parsers
from
openfold.data.tools
import
utils
from
openfold.data.tools
import
utils
...
@@ -71,11 +71,12 @@ class HHSearch:
...
@@ -71,11 +71,12 @@ class HHSearch:
def
input_format
(
self
)
->
str
:
def
input_format
(
self
)
->
str
:
return
'a3m'
return
'a3m'
def
query
(
self
,
a3m
:
str
)
->
str
:
def
query
(
self
,
a3m
:
str
,
output_dir
:
Optional
[
str
]
=
None
)
->
str
:
"""Queries the database using HHsearch using a given a3m."""
"""Queries the database using HHsearch using a given a3m."""
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
input_path
=
os
.
path
.
join
(
query_tmp_dir
,
"query.a3m"
)
hhr_path
=
os
.
path
.
join
(
query_tmp_dir
,
"output.hhr"
)
output_dir
=
query_tmp_dir
if
output_dir
is
None
else
output_dir
hhr_path
=
os
.
path
.
join
(
output_dir
,
"hhsearch_output.hhr"
)
with
open
(
input_path
,
"w"
)
as
f
:
with
open
(
input_path
,
"w"
)
as
f
:
f
.
write
(
a3m
)
f
.
write
(
a3m
)
...
@@ -114,7 +115,8 @@ class HHSearch:
...
@@ -114,7 +115,8 @@ class HHSearch:
hhr
=
f
.
read
()
hhr
=
f
.
read
()
return
hhr
return
hhr
def
get_template_hits
(
self
,
@
staticmethod
def
get_template_hits
(
output_string
:
str
,
output_string
:
str
,
input_sequence
:
str
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
)
->
Sequence
[
parsers
.
TemplateHit
]:
...
...
openfold/data/tools/hmmsearch.py
View file @
4bd1b4d5
...
@@ -28,11 +28,12 @@ class Hmmsearch(object):
...
@@ -28,11 +28,12 @@ class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""
"""Python wrapper of the hmmsearch binary."""
def
__init__
(
self
,
def
__init__
(
self
,
*
,
*
,
binary_path
:
str
,
binary_path
:
str
,
hmmbuild_binary_path
:
str
,
hmmbuild_binary_path
:
str
,
database_path
:
str
,
database_path
:
str
,
flags
:
Optional
[
Sequence
[
str
]]
=
None
):
flags
:
Optional
[
Sequence
[
str
]]
=
None
):
"""Initializes the Python hmmsearch wrapper.
"""Initializes the Python hmmsearch wrapper.
Args:
Args:
...
@@ -71,17 +72,23 @@ class Hmmsearch(object):
...
@@ -71,17 +72,23 @@ class Hmmsearch(object):
def
input_format
(
self
)
->
str
:
def
input_format
(
self
)
->
str
:
return
'sto'
return
'sto'
def
query
(
self
,
msa_sto
:
str
)
->
str
:
def
query
(
self
,
msa_sto
:
str
,
output_dir
:
Optional
[
str
]
=
None
)
->
str
:
"""Queries the database using hmmsearch using a given stockholm msa."""
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm
=
self
.
hmmbuild_runner
.
build_profile_from_sto
(
msa_sto
,
hmm
=
self
.
hmmbuild_runner
.
build_profile_from_sto
(
model_construction
=
'hand'
)
msa_sto
,
return
self
.
query_with_hmm
(
hmm
)
model_construction
=
'hand'
)
return
self
.
query_with_hmm
(
hmm
,
output_dir
)
def
query_with_hmm
(
self
,
hmm
:
str
)
->
str
:
def
query_with_hmm
(
self
,
hmm
:
str
,
output_dir
:
Optional
[
str
]
=
None
)
->
str
:
"""Queries the database using hmmsearch using a given hmm."""
"""Queries the database using hmmsearch using a given hmm."""
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
with
utils
.
tmpdir_manager
()
as
query_tmp_dir
:
hmm_input_path
=
os
.
path
.
join
(
query_tmp_dir
,
'query.hmm'
)
hmm_input_path
=
os
.
path
.
join
(
query_tmp_dir
,
'query.hmm'
)
out_path
=
os
.
path
.
join
(
query_tmp_dir
,
'output.sto'
)
output_dir
=
query_tmp_dir
if
output_dir
is
None
else
output_dir
out_path
=
os
.
path
.
join
(
output_dir
,
'hmm_output.sto'
)
with
open
(
hmm_input_path
,
'w'
)
as
f
:
with
open
(
hmm_input_path
,
'w'
)
as
f
:
f
.
write
(
hmm
)
f
.
write
(
hmm
)
...
@@ -117,18 +124,14 @@ class Hmmsearch(object):
...
@@ -117,18 +124,14 @@ class Hmmsearch(object):
return
out_msa
return
out_msa
def
get_template_hits
(
self
,
@
staticmethod
def
get_template_hits
(
output_string
:
str
,
output_string
:
str
,
input_sequence
:
str
input_sequence
:
str
)
->
Sequence
[
parsers
.
TemplateHit
]:
)
->
Sequence
[
parsers
.
TemplateHit
]:
"""Gets parsed template hits from the raw string output by the tool."""
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string
=
parsers
.
convert_stockholm_to_a3m
(
template_hits
=
parsers
.
parse_hmmsearch_sto
(
output_string
,
output_string
,
remove_first_row_gaps
=
False
input_sequence
,
)
template_hits
=
parsers
.
parse_hmmsearch_a3m
(
query_sequence
=
input_sequence
,
a3m_string
=
a3m_string
,
skip_first
=
False
)
)
return
template_hits
return
template_hits
openfold/model/embedders.py
View file @
4bd1b4d5
...
@@ -13,12 +13,26 @@
...
@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
functools
import
partial
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
typing
import
Tuple
from
typing
import
Tuple
from
openfold.utils
import
all_atom_multimer
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
dgram_from_positions
,
build_template_angle_feat
,
build_template_pair_feat
,
)
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.model.primitives
import
Linear
,
LayerNorm
from
openfold.utils.tensor_utils
import
one_hot
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePointwiseAttention
,
)
from
openfold.utils
import
geometry
from
openfold.utils.tensor_utils
import
one_hot
,
tensor_tree_map
,
dict_multimap
class
InputEmbedder
(
nn
.
Module
):
class
InputEmbedder
(
nn
.
Module
):
...
@@ -85,20 +99,16 @@ class InputEmbedder(nn.Module):
...
@@ -85,20 +99,16 @@ class InputEmbedder(nn.Module):
oh
=
one_hot
(
d
,
boundaries
).
type
(
ri
.
dtype
)
oh
=
one_hot
(
d
,
boundaries
).
type
(
ri
.
dtype
)
return
self
.
linear_relpos
(
oh
)
return
self
.
linear_relpos
(
oh
)
def
forward
(
def
forward
(
self
,
batch
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
,
tf
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
tf:
batch: Dict containing
"target_feat" features of shape [*, N_res, tf_dim]
"target_feat":
ri:
Features of shape [*, N_res, tf_dim]
"residue_index" features of shape [*, N_res]
"residue_index":
msa:
Features of shape [*, N_res]
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
"msa_feat":
Features of shape [*, N_clust, N_res, msa_dim]
Returns:
Returns:
msa_emb:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
[*, N_clust, N_res, C_m] MSA embedding
...
@@ -106,6 +116,10 @@ class InputEmbedder(nn.Module):
...
@@ -106,6 +116,10 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
"""
"""
tf
=
batch
[
"target_feat"
]
ri
=
batch
[
"residue_index"
]
msa
=
batch
[
"msa_feat"
]
# [*, N_res, c_z]
# [*, N_res, c_z]
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
...
@@ -126,6 +140,154 @@ class InputEmbedder(nn.Module):
...
@@ -126,6 +140,154 @@ class InputEmbedder(nn.Module):
return
msa_emb
,
pair_emb
return
msa_emb
,
pair_emb
class InputEmbedderMultimer(nn.Module):
    """
    Embeds a subset of the input features.

    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        max_relative_idx: int,
        use_chain_relative: bool,
        max_relative_chain: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            max_relative_idx:
                Window size used in relative positional encoding
            use_chain_relative:
                Whether to add chain-relative features (entity/sym ids) to
                the relative positional encoding
            max_relative_chain:
                Window size for the relative-chain component of the encoding
        """
        super(InputEmbedderMultimer, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.max_relative_idx = max_relative_idx
        self.use_chain_relative = use_chain_relative
        self.max_relative_chain = max_relative_chain
        if(self.use_chain_relative):
            # 2k+2 one-hot bins for clipped residue offsets (incl. the
            # "different chain" overflow bin), 1 for same-entity flag, and
            # 2k'+2 bins for the clipped relative chain index
            self.no_bins = (
                2 * max_relative_idx + 2 +
                1 +
                2 * max_relative_chain + 2
            )
        else:
            self.no_bins = 2 * max_relative_idx + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, batch):
        """
        Computes the relative positional encoding (Algorithm 4, multimer
        variant).

        Args:
            batch: Dict containing "residue_index" and "asym_id" features of
                shape [*, N_res]; when use_chain_relative is set, also
                "entity_id" and "sym_id" of the same shape.
        Returns:
            [*, N_res, N_res, C_z] relative positional encoding
        """
        pos = batch["residue_index"]
        asym_id = batch["asym_id"]
        # [*, N_res, N_res] same-chain indicator
        asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
        offset = pos[..., None] - pos[..., None, :]

        # Shift offsets into [0, 2k] so they can be one-hot encoded
        clipped_offset = torch.clamp(
            offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
        )

        rel_feats = []
        if(self.use_chain_relative):
            # Residue offsets across different chains all map to one extra
            # overflow bin (index 2k+1)
            final_offset = torch.where(
                asym_id_same,
                clipped_offset,
                (2 * self.max_relative_idx + 1) *
                torch.ones_like(clipped_offset)
            )

            rel_pos = torch.nn.functional.one_hot(
                final_offset,
                2 * self.max_relative_idx + 2,
            )
            rel_feats.append(rel_pos)

            entity_id = batch["entity_id"]
            entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
            rel_feats.append(entity_id_same[..., None])

            sym_id = batch["sym_id"]
            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]

            max_rel_chain = self.max_relative_chain
            clipped_rel_chain = torch.clamp(
                rel_sym_id + max_rel_chain,
                0,
                2 * max_rel_chain,
            )

            # Chain offsets across different entities map to the overflow bin
            final_rel_chain = torch.where(
                entity_id_same,
                clipped_rel_chain,
                (2 * max_rel_chain + 1) *
                torch.ones_like(clipped_rel_chain)
            )
            rel_chain = torch.nn.functional.one_hot(
                final_rel_chain,
                2 * max_rel_chain + 2,
            )
            rel_feats.append(rel_chain)
        else:
            rel_pos = torch.nn.functional.one_hot(
                clipped_offset,
                2 * self.max_relative_idx + 1,
            )
            rel_feats.append(rel_pos)

        # Cast the (integer) one-hot features to the projection's dtype
        rel_feat = torch.cat(rel_feats, dim=-1).to(
            self.linear_relpos.weight.dtype
        )

        return self.linear_relpos(rel_feat)

    def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            batch: Dict containing
                "target_feat": features of shape [*, N_res, tf_dim]
                "msa_feat": features of shape [*, N_clust, N_res, msa_dim]
                plus the features read by relpos()
        Returns:
            msa_emb:
                [*, N_clust, N_res, C_m] MSA embedding
            pair_emb:
                [*, N_res, N_res, C_z] pair embedding
        """
        tf = batch["target_feat"]
        msa = batch["msa_feat"]

        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]: outer sum of the two projections plus the
        # relative positional encoding
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(batch)

        # [*, N_clust, N_res, c_m]: broadcast the target-feature projection
        # across the cluster dimension and add the MSA projection
        n_clust = msa.shape[-3]
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
class
RecyclingEmbedder
(
nn
.
Module
):
class
RecyclingEmbedder
(
nn
.
Module
):
"""
"""
Embeds the output of an iteration of the model for recycling.
Embeds the output of an iteration of the model for recycling.
...
@@ -312,6 +474,102 @@ class TemplatePairEmbedder(nn.Module):
...
@@ -312,6 +474,102 @@ class TemplatePairEmbedder(nn.Module):
return
x
return
x
class TemplateEmbedder(nn.Module):
    """
    Embeds template features and folds them into the pair representation.

    Wraps the template angle/pair embedders, the template pair stack, and
    the pointwise attention that projects the stacked template embeddings
    back onto the pair embedding z.
    """
    def __init__(
        self,
        config,
    ):
        """
        Args:
            config: Template-embedder config subtree providing kwargs for
                each submodule plus embed_angles, use_unit_vector, inf, eps
                and distogram settings.
        """
        super().__init__()

        self.config = config
        self.template_angle_embedder = TemplateAngleEmbedder(
            **config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **config["template_pointwise_attention"],
        )

    def forward(
        self,
        batch,
        z,
        pair_mask,
        templ_dim,
        chunk_size,
        _mask_trans=True,
    ):
        """
        Args:
            batch: Dict of template features, including "template_aatype"
                and "template_mask"
            z: [*, N, N, C_z] pair embedding
            pair_mask: [*, N, N] pair mask
            templ_dim: Index of the template dimension in the batch tensors
            chunk_size: Inference-time chunk size for the submodules
            _mask_trans: Whether to mask the pair-stack transitions
        Returns:
            Dict with "template_pair_embedding" ([*, N, N, C_z] update for
            z) and, when config.embed_angles is set,
            "template_angle_embedding".
        """
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                use_unit_vector=self.config.use_unit_vector,
                inf=self.config.inf,
                eps=self.config.eps,
                **self.config.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=chunk_size,
        )
        # Zero the pair update when no valid templates are present
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.embed_angles:
            # BUG FIX: the angle embedding was previously stored under
            # "template_pair_embedding" and then immediately clobbered by
            # the ret.update(...) below, so it was always lost. Store it
            # under its own key instead.
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret
class
ExtraMSAEmbedder
(
nn
.
Module
):
class
ExtraMSAEmbedder
(
nn
.
Module
):
"""
"""
Embeds unclustered MSA sequences.
Embeds unclustered MSA sequences.
...
@@ -350,3 +608,315 @@ class ExtraMSAEmbedder(nn.Module):
...
@@ -350,3 +608,315 @@ class ExtraMSAEmbedder(nn.Module):
x
=
self
.
linear
(
x
)
x
=
self
.
linear
(
x
)
return
x
return
x
class TemplateEmbedder(nn.Module):
    """Embeds template features into single- and pair-representation space.

    Combines a per-template angle embedder, a per-template pair embedder,
    the template pair stack, and pointwise attention that merges the
    result into the model's pair representation.
    """

    def __init__(self, config):
        """
        Args:
            config:
                Template config node providing the sub-configs consumed
                below plus embed_angles / use_unit_vector / inf / eps /
                distogram, which are read in forward().
        """
        super().__init__()
        self.config = config
        self.template_angle_embedder = TemplateAngleEmbedder(
            **config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **config["template_pointwise_attention"],
        )

    def forward(
        self,
        batch,
        z,
        pair_mask,
        templ_dim,
        chunk_size,
        _mask_trans=True,
    ):
        """
        Args:
            batch: Dict of "template_*" features
            z: [*, N, N, C_z] pair representation
            pair_mask: [*, N, N] pair mask
            templ_dim: Template dimension index within the template feats
            chunk_size: Inference-time chunk size
            _mask_trans: Whether pair-stack transitions are masked
        Returns:
            Dict with "template_pair_embedding" and, when
            config.embed_angles, "template_single_embedding".
        """
        # Process templates one by one (a poor man's vmap)
        per_template = []
        num_templates = batch["template_aatype"].shape[templ_dim]
        for template_idx in range(num_templates):
            sel = batch["template_aatype"].new_tensor(template_idx)
            feats_i = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, sel),
                batch,
            )

            embeds_i = {}
            if self.config.embed_angles:
                angle_feat = build_template_angle_feat(
                    feats_i,
                )
                # [*, S_t, N, C_m]
                embeds_i["angle"] = self.template_angle_embedder(angle_feat)

            # [*, S_t, N, N, C_t]
            pair_feat = build_template_pair_feat(
                feats_i,
                use_unit_vector=self.config.use_unit_vector,
                inf=self.config.inf,
                eps=self.config.eps,
                **self.config.distogram,
            ).to(z.dtype)
            embeds_i["pair"] = self.template_pair_embedder(pair_feat)

            per_template.append(embeds_i)

        stacked = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            per_template,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            stacked["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            _mask_trans=_mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=chunk_size,
        )

        # Kill the whole contribution if there are no templates at all
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {"template_pair_embedding": t}
        if self.config.embed_angles:
            ret["template_single_embedding"] = stacked["angle"]

        return ret
class TemplatePairEmbedderMultimer(nn.Module):
    """Multimer-style template pair embedder.

    Sums linear projections of the template distogram, residue identities,
    backbone unit vectors, masks, and the (normalized) query pair embedding
    into a single [*, N, N, c_out] activation.
    """

    def __init__(self,
        c_z: int,
        c_out: int,
        c_dgram: int,
        c_aatype: int,
    ):
        """
        Args:
            c_z: Channel dim of the query pair embedding
            c_out: Output channel dim
            c_dgram: Number of distogram bins
            c_aatype: Size of the residue-type one-hot
        """
        super().__init__()
        self.dgram_linear = Linear(c_dgram, c_out)
        self.aatype_linear_1 = Linear(c_aatype, c_out)
        self.aatype_linear_2 = Linear(c_aatype, c_out)
        self.query_embedding_layer_norm = LayerNorm(c_z)
        self.query_embedding_linear = Linear(c_z, c_out)
        self.pseudo_beta_mask_linear = Linear(1, c_out)
        self.x_linear = Linear(1, c_out)
        self.y_linear = Linear(1, c_out)
        self.z_linear = Linear(1, c_out)
        self.backbone_mask_linear = Linear(1, c_out)

    def forward(self,
        template_dgram: torch.Tensor,
        aatype_one_hot: torch.Tensor,
        query_embedding: torch.Tensor,
        pseudo_beta_mask: torch.Tensor,
        backbone_mask: torch.Tensor,
        multichain_mask_2d: torch.Tensor,
        unit_vector: geometry.Vec3Array,
    ) -> torch.Tensor:
        """Sum all template pair signals into one activation.

        NOTE(review): template_dgram is masked in place here; callers must
        not rely on it being unchanged afterwards.
        """
        out = 0.

        # Pairwise pseudo-beta mask, restricted to within-chain pairs
        pb_mask_2d = (
            pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
        )
        pb_mask_2d *= multichain_mask_2d
        template_dgram *= pb_mask_2d[..., None]
        out += self.dgram_linear(template_dgram)
        out += self.pseudo_beta_mask_linear(pb_mask_2d[..., None])

        # Residue identities, broadcast along rows and columns separately
        aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
        out += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
        out += self.aatype_linear_2(aatype_one_hot[..., None, :])

        # Backbone-frame unit vectors, masked to valid within-chain pairs
        bb_mask_2d = (
            backbone_mask[..., None] * backbone_mask[..., None, :]
        )
        bb_mask_2d *= multichain_mask_2d
        x, y, z = (c * bb_mask_2d for c in unit_vector)
        out += self.x_linear(x[..., None])
        out += self.y_linear(y[..., None])
        out += self.z_linear(z[..., None])
        out += self.backbone_mask_linear(bb_mask_2d[..., None])

        # Normalized query pair embedding
        query_embedding = self.query_embedding_layer_norm(query_embedding)
        out += self.query_embedding_linear(query_embedding)

        return out
class TemplateSingleEmbedderMultimer(nn.Module):
    """Multimer-style single-representation template embedder.

    Builds per-residue features from residue identity and chi-angle
    sin/cos values, then applies a two-layer (Linear-ReLU-Linear) MLP.
    """

    def __init__(self,
        c_in: int,
        c_m: int,
    ):
        """
        Args:
            c_in: Input feature dim (one-hot + chi sin/cos + chi mask)
            c_m: MSA/single channel dim
        """
        super().__init__()
        self.template_single_embedder = Linear(c_in, c_m)
        self.template_projector = Linear(c_m, c_m)

    def forward(self,
        batch,
        atom_pos,
        aatype_one_hot,
    ):
        """
        Args:
            batch: Dict with "template_all_atom_mask" and "template_aatype"
            atom_pos: Template atom positions (geometry.Vec3Array)
            aatype_one_hot: One-hot residue identities
        Returns:
            Dict with "template_single_embedding" and "template_mask".
        """
        chi_angles, chi_mask = (
            all_atom_multimer.compute_chi_angles(
                atom_pos,
                batch["template_all_atom_mask"],
                batch["template_aatype"],
            )
        )

        # Identity + masked sin/cos of the chi angles + the mask itself
        feats = torch.cat(
            [
                aatype_one_hot,
                torch.sin(chi_angles) * chi_mask,
                torch.cos(chi_angles) * chi_mask,
                chi_mask,
            ],
            dim=-1,
        )

        # First chi-mask channel doubles as the per-residue template mask
        mask = chi_mask[..., 0]

        # Linear -> ReLU -> Linear
        act = self.template_single_embedder(feats)
        act = torch.nn.functional.relu(act)
        act = self.template_projector(act)

        return {
            "template_single_embedding": act,
            "template_mask": mask,
        }
class TemplateEmbedderMultimer(nn.Module):
    """Multimer template embedder.

    Embeds each template's distogram, residue identities and backbone
    frames into the pair space, runs the template pair stack, and averages
    over templates before projecting back to C_z. Also produces the
    per-template single embeddings and masks.
    """

    def __init__(self, config):
        """
        Args:
            config:
                Template config node with "template_pair_embedder",
                "template_single_embedder" and "template_pair_stack"
                sub-configs, plus c_t / c_z / inf / distogram.
        """
        super().__init__()
        self.config = config
        self.template_pair_embedder = TemplatePairEmbedderMultimer(
            **config["template_pair_embedder"],
        )
        self.template_single_embedder = TemplateSingleEmbedderMultimer(
            **config["template_single_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )
        self.linear_t = Linear(config.c_t, config.c_z)

    def forward(self,
        batch,
        z,
        padding_mask_2d,
        templ_dim,
        chunk_size,
        multichain_mask_2d,
    ):
        """
        Args:
            batch: Dict of "template_*" features
            z: [*, N, N, C_z] pair representation
            padding_mask_2d: [*, N, N] pair padding mask
            templ_dim: Template dimension index within the template feats
            chunk_size: Inference-time chunk size
            multichain_mask_2d: [*, N, N] mask of within-chain pairs
        Returns:
            Dict with "template_pair_embedding" ([*, N, N, C_z]),
            "template_single_embedding" and "template_mask".
        """
        # Embed the templates one at a time (a poor man's vmap).
        # NOTE(review): n_templ == 0 would make the mean below divide by
        # zero; callers are presumed to guard against empty template sets.
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            template_positions, pseudo_beta_mask = (
                single_template_feats["template_pseudo_beta"],
                single_template_feats["template_pseudo_beta_mask"],
            )
            template_dgram = dgram_from_positions(
                template_positions,
                inf=self.config.inf,
                **self.config.distogram,
            )
            aatype_one_hot = torch.nn.functional.one_hot(
                single_template_feats["template_aatype"],
                22,
            )

            raw_atom_pos = single_template_feats["template_all_atom_positions"]
            atom_pos = geometry.Vec3Array.from_tensor(raw_atom_pos)

            # Backbone frames -> pairwise inter-residue unit vectors
            rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
                atom_pos,
                single_template_feats["template_all_atom_mask"],
                single_template_feats["template_aatype"],
            )
            points = rigid.translation
            rigid_vec = rigid[..., None].inverse().apply_to_point(points)
            unit_vector = rigid_vec.normalized()

            pair_act = self.template_pair_embedder(
                template_dgram,
                aatype_one_hot,
                z,
                pseudo_beta_mask,
                backbone_mask,
                multichain_mask_2d,
                unit_vector,
            )

            single_template_embeds["template_pair_embedding"] = pair_act
            single_template_embeds.update(
                self.template_single_embedder(
                    single_template_feats,
                    atom_pos,
                    aatype_one_hot,
                )
            )
            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["template_pair_embedding"],
            padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=chunk_size,
            _mask_trans=False,
        )

        # [*, N, N, C_z]: average over templates, then ReLU + project
        t = torch.sum(t, dim=-4) / n_templ
        t = torch.nn.functional.relu(t)
        t = self.linear_t(t)
        template_embeds["template_pair_embedding"] = t

        return template_embeds
openfold/model/model.py
View file @
4bd1b4d5
...
@@ -17,28 +17,25 @@ from functools import partial
...
@@ -17,28 +17,25 @@ from functools import partial
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
openfold.data
import
data_transforms_multimer
from
openfold.utils.feats
import
(
from
openfold.utils.feats
import
(
pseudo_beta_fn
,
pseudo_beta_fn
,
build_extra_msa_feat
,
build_extra_msa_feat
,
build_template_angle_feat
,
dgram_from_positions
,
build_template_pair_feat
,
atom14_to_atom37
,
atom14_to_atom37
,
)
)
from
openfold.model.embedders
import
(
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
RecyclingEmbedder
,
Template
Angle
Embedder
,
TemplateEmbedder
,
Template
Pair
Embedder
,
TemplateEmbedder
Multimer
,
ExtraMSAEmbedder
,
ExtraMSAEmbedder
,
)
)
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.evoformer
import
EvoformerStack
,
ExtraMSAStack
from
openfold.model.heads
import
AuxiliaryHeads
from
openfold.model.heads
import
AuxiliaryHeads
import
openfold.np.residue_constants
as
residue_constants
import
openfold.np.residue_constants
as
residue_constants
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.structure_module
import
StructureModule
from
openfold.model.template
import
(
TemplatePairStack
,
TemplatePointwiseAttention
,
)
from
openfold.utils.loss
import
(
from
openfold.utils.loss
import
(
compute_plddt
,
compute_plddt
,
)
)
...
@@ -69,24 +66,28 @@ class AlphaFold(nn.Module):
...
@@ -69,24 +66,28 @@ class AlphaFold(nn.Module):
extra_msa_config
=
config
.
extra_msa
extra_msa_config
=
config
.
extra_msa
# Main trunk + structure module
# Main trunk + structure module
self
.
input_embedder
=
InputEmbedder
(
if
(
self
.
globals
.
is_multimer
):
**
config
[
"input_embedder"
],
self
.
input_embedder
=
InputEmbedderMultimer
(
)
**
config
[
"input_embedder"
],
)
else
:
self
.
input_embedder
=
InputEmbedder
(
**
config
[
"input_embedder"
],
)
self
.
recycling_embedder
=
RecyclingEmbedder
(
self
.
recycling_embedder
=
RecyclingEmbedder
(
**
config
[
"recycling_embedder"
],
**
config
[
"recycling_embedder"
],
)
)
self
.
template_angle_embedder
=
TemplateAngleEmbedder
(
**
template_config
[
"template_angle_embedder"
],
if
(
self
.
globals
.
is_multimer
):
)
self
.
template_embedder
=
TemplateEmbedderMultimer
(
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
template_config
,
**
template_config
[
"template_pair_embedder"
],
)
)
else
:
self
.
template_pair_stack
=
TemplatePairStack
(
self
.
template_embedder
=
TemplateEmbedder
(
**
template_config
[
"template_pair_stack"
],
template_config
,
)
)
self
.
template_pointwise_att
=
TemplatePointwiseAttention
(
**
template_config
[
"template_pointwise_attention"
],
)
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
self
.
extra_msa_embedder
=
ExtraMSAEmbedder
(
**
extra_msa_config
[
"extra_msa_embedder"
],
**
extra_msa_config
[
"extra_msa_embedder"
],
)
)
...
@@ -96,7 +97,9 @@ class AlphaFold(nn.Module):
...
@@ -96,7 +97,9 @@ class AlphaFold(nn.Module):
self
.
evoformer
=
EvoformerStack
(
self
.
evoformer
=
EvoformerStack
(
**
config
[
"evoformer_stack"
],
**
config
[
"evoformer_stack"
],
)
)
self
.
structure_module
=
StructureModule
(
self
.
structure_module
=
StructureModule
(
is_multimer
=
self
.
globals
.
is_multimer
,
**
config
[
"structure_module"
],
**
config
[
"structure_module"
],
)
)
...
@@ -106,71 +109,6 @@ class AlphaFold(nn.Module):
...
@@ -106,71 +109,6 @@ class AlphaFold(nn.Module):
self
.
config
=
config
self
.
config
=
config
    def embed_templates(self, batch, z, pair_mask, templ_dim):
        """Embed template features and fold them into the pair rep.

        Args:
            batch: Dict of "template_*" features
            z: [*, N, N, C_z] pair representation
            pair_mask: [*, N, N] pair mask
            templ_dim: Index of the template dimension in the template feats
        Returns:
            Dict with "template_pair_embedding" ([*, N, N, C_z]) and, when
            config.template.embed_angles, "template_angle_embedding".
        """
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            # Slice out template i from every template feature tensor
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        # Concatenate the per-template dicts back along the template dim
        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, S_t, N, N, C_z]
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # [*, N, N, C_z]
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )

        # Zero the whole contribution when no template is present
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
_recycle
=
True
):
def
iteration
(
self
,
feats
,
m_1_prev
,
z_prev
,
x_prev
,
_recycle
=
True
):
# Primary output dictionary
# Primary output dictionary
outputs
=
{}
outputs
=
{}
...
@@ -197,11 +135,7 @@ class AlphaFold(nn.Module):
...
@@ -197,11 +135,7 @@ class AlphaFold(nn.Module):
# m: [*, S_c, N, C_m]
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
# z: [*, N, N, C_z]
m
,
z
=
self
.
input_embedder
(
m
,
z
=
self
.
input_embedder
(
feats
)
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"msa_feat"
],
)
# Initialize the recycling embeddings, if needs be
# Initialize the recycling embeddings, if needs be
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
if
None
in
[
m_1_prev
,
z_prev
,
x_prev
]:
...
@@ -257,40 +191,74 @@ class AlphaFold(nn.Module):
...
@@ -257,40 +191,74 @@ class AlphaFold(nn.Module):
template_feats
=
{
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
}
template_embeds
=
self
.
embed_templates
(
template_feats
,
if
(
self
.
globals
.
is_multimer
):
z
,
asym_id
=
feats
[
"asym_id"
]
pair_mask
.
to
(
dtype
=
z
.
dtype
),
multichain_mask_2d
=
(
no_batch_dims
,
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
)
)
template_embeds
=
self
.
template_embedder
(
template_feats
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
)
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
)
else
:
template_embeds
=
self
.
template_embedder
(
template_feats
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
self
.
globals
.
chunk_size
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
z
=
z
+
template_embeds
[
"template_pair_embedding"
]
if
self
.
config
.
template
.
embed_angles
:
if
(
self
.
config
.
template
.
embed_angles
or
(
self
.
globals
.
is_multimer
and
self
.
config
.
template
.
enabled
)
):
# [*, S = S_c + S_t, N, C_m]
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_
a
ngle_embedding"
]],
[
m
,
template_embeds
[
"template_
si
ngle_embedding"
]],
dim
=-
3
dim
=-
3
)
)
# [*, S, N]
# [*, S, N]
torsion_angles_mask
=
feats
[
"template_torsion_angles_mask"
]
if
(
not
self
.
globals
.
is_multimer
):
msa_mask
=
torch
.
cat
(
torsion_angles_mask
=
feats
[
"template_torsion_angles_mask"
]
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
msa_mask
=
torch
.
cat
(
dim
=-
2
[
feats
[
"msa_mask"
],
torsion_angles_mask
[...,
2
]],
)
dim
=-
2
)
else
:
msa_mask
=
torch
.
cat
(
[
feats
[
"msa_mask"
],
template_embeds
[
"template_mask"
]],
dim
=-
2
,
)
# Embed extra MSA features + merge with pairwise embeddings
# Embed extra MSA features + merge with pairwise embeddings
if
self
.
config
.
extra_msa
.
enabled
:
if
self
.
config
.
extra_msa
.
enabled
:
if
(
self
.
globals
.
is_multimer
):
extra_msa_fn
=
data_transforms_multimer
.
build_extra_msa_feat
else
:
extra_msa_fn
=
build_extra_msa_feat
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
extra_msa_feat
=
extra_msa_fn
(
feats
)
extra_msa_feat
=
self
.
extra_msa_embedder
(
extra_msa_feat
)
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
z
=
self
.
extra_msa_stack
(
a
,
extra_msa_feat
,
z
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
a
.
dtype
),
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
extra_msa_feat
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
...
@@ -340,14 +308,14 @@ class AlphaFold(nn.Module):
...
@@ -340,14 +308,14 @@ class AlphaFold(nn.Module):
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
def
_disable_activation_checkpointing
(
self
):
def
_disable_activation_checkpointing
(
self
):
self
.
template_pair_stack
.
blocks_per_ckpt
=
None
self
.
template_
embedder
.
template_
pair_stack
.
blocks_per_ckpt
=
None
self
.
evoformer
.
blocks_per_ckpt
=
None
self
.
evoformer
.
blocks_per_ckpt
=
None
for
b
in
self
.
extra_msa_stack
.
blocks
:
for
b
in
self
.
extra_msa_stack
.
blocks
:
b
.
ckpt
=
False
b
.
ckpt
=
False
def
_enable_activation_checkpointing
(
self
):
def
_enable_activation_checkpointing
(
self
):
self
.
template_pair_stack
.
blocks_per_ckpt
=
(
self
.
template_
embedder
.
template_
pair_stack
.
blocks_per_ckpt
=
(
self
.
config
.
template
.
template_pair_stack
.
blocks_per_ckpt
self
.
config
.
template
.
template_pair_stack
.
blocks_per_ckpt
)
)
self
.
evoformer
.
blocks_per_ckpt
=
(
self
.
evoformer
.
blocks_per_ckpt
=
(
...
...
openfold/model/structure_module.py
View file @
4bd1b4d5
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
restype_atom14_mask
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom14_rigid_group_positions
,
)
)
from
openfold.utils.geometry.quat_rigid
import
QuatRigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.vector
import
Vec3Array
from
openfold.utils.feats
import
(
from
openfold.utils.feats
import
(
frames_and_literature_positions_to_atom14_pos
,
frames_and_literature_positions_to_atom14_pos
,
torsion_angles_to_frames
,
torsion_angles_to_frames
,
...
@@ -155,14 +158,14 @@ class PointProjection(nn.Module):
...
@@ -155,14 +158,14 @@ class PointProjection(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
c_hidden
:
int
,
c_hidden
:
int
,
num_points
:
int
,
num_points
:
int
,
no_heads
:
int
no_heads
:
int
,
return_local_points
:
bool
=
False
,
return_local_points
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
return_local_points
=
return_local_points
self
.
return_local_points
=
return_local_points
self
.
no_heads
=
no_heads
self
.
no_heads
=
no_heads
self
.
linear
=
Linear
(
c_hidden
,
3
*
num_points
)
self
.
linear
=
Linear
(
c_hidden
,
no_heads
*
3
*
num_points
)
def
forward
(
self
,
def
forward
(
self
,
activations
:
torch
.
Tensor
,
activations
:
torch
.
Tensor
,
...
@@ -171,11 +174,13 @@ class PointProjection(nn.Module):
...
@@ -171,11 +174,13 @@ class PointProjection(nn.Module):
# TODO: Needs to run in high precision during training
# TODO: Needs to run in high precision during training
points_local
=
self
.
linear
(
activations
)
points_local
=
self
.
linear
(
activations
)
points_local
=
points_local
.
reshape
(
points_local
=
points_local
.
reshape
(
points_local
.
shape
[:
-
1
],
*
points_local
.
shape
[:
-
1
],
self
.
no_heads
,
self
.
no_heads
,
-
1
,
-
1
,
)
)
points_local
=
torch
.
split
(
points_local
,
3
,
dim
=-
1
)
points_local
=
torch
.
split
(
points_local
,
points_local
.
shape
[
-
1
]
//
3
,
dim
=-
1
)
points_local
=
Vec3Array
(
*
points_local
)
points_local
=
Vec3Array
(
*
points_local
)
points_global
=
rigids
[...,
None
,
None
].
apply_to_point
(
points_local
)
points_global
=
rigids
[...,
None
,
None
].
apply_to_point
(
points_local
)
...
@@ -184,7 +189,7 @@ class PointProjection(nn.Module):
...
@@ -184,7 +189,7 @@ class PointProjection(nn.Module):
return
points_global
return
points_global
# WEIGHTS CHANGED
class
InvariantPointAttention
(
nn
.
Module
):
class
InvariantPointAttention
(
nn
.
Module
):
"""
"""
Implements Algorithm 22.
Implements Algorithm 22.
...
@@ -199,6 +204,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -199,6 +204,7 @@ class InvariantPointAttention(nn.Module):
no_v_points
:
int
,
no_v_points
:
int
,
inf
:
float
=
1e5
,
inf
:
float
=
1e5
,
eps
:
float
=
1e-8
,
eps
:
float
=
1e-8
,
is_multimer
:
bool
=
False
,
):
):
"""
"""
Args:
Args:
...
@@ -225,14 +231,14 @@ class InvariantPointAttention(nn.Module):
...
@@ -225,14 +231,14 @@ class InvariantPointAttention(nn.Module):
self
.
no_v_points
=
no_v_points
self
.
no_v_points
=
no_v_points
self
.
inf
=
inf
self
.
inf
=
inf
self
.
eps
=
eps
self
.
eps
=
eps
self
.
is_multimer
=
is_multimer
# These linear layers differ from their specifications in the
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Here as in the official source, they have bias and use the default
# Lecun initialization.
# Lecun initialization.
hc
=
self
.
c_hidden
*
self
.
no_heads
hc
=
self
.
c_hidden
*
self
.
no_heads
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
)
self
.
linear_q
=
Linear
(
self
.
c_s
,
hc
,
bias
=
(
not
is_multimer
))
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
self
.
linear_q_points
=
PointProjection
(
self
.
linear_q_points
=
PointProjection
(
self
.
c_s
,
self
.
c_s
,
...
@@ -240,17 +246,27 @@ class InvariantPointAttention(nn.Module):
...
@@ -240,17 +246,27 @@ class InvariantPointAttention(nn.Module):
self
.
no_heads
self
.
no_heads
)
)
self
.
linear_k_points
=
PointProjection
(
if
(
is_multimer
):
self
.
c_s
,
self
.
linear_k
=
Linear
(
self
.
c_s
,
hc
,
bias
=
False
)
self
.
no_qk_points
self
.
linear_v
=
Linear
(
self
.
c_s
,
hc
,
bias
=
False
)
self
.
no_heads
,
self
.
linear_k_points
=
PointProjection
(
)
self
.
c_s
,
self
.
no_qk_points
,
self
.
no_heads
,
)
self
.
linear_v_points
=
PointProjection
(
self
.
linear_v_points
=
PointProjection
(
self
.
c_s
,
self
.
c_s
,
self
.
no_v_points
self
.
no_v_points
,
self
.
no_heads
,
self
.
no_heads
,
)
)
else
:
self
.
linear_kv
=
Linear
(
self
.
c_s
,
2
*
hc
)
self
.
linear_kv_points
=
PointProjection
(
self
.
c_s
,
self
.
no_qk_points
+
self
.
no_v_points
,
self
.
no_heads
,
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
self
.
linear_b
=
Linear
(
self
.
c_z
,
self
.
no_heads
)
...
@@ -290,25 +306,48 @@ class InvariantPointAttention(nn.Module):
...
@@ -290,25 +306,48 @@ class InvariantPointAttention(nn.Module):
#######################################
#######################################
# [*, N_res, H * C_hidden]
# [*, N_res, H * C_hidden]
q
=
self
.
linear_q
(
s
)
q
=
self
.
linear_q
(
s
)
kv
=
self
.
linear_kv
(
s
)
# [*, N_res, H, C_hidden]
# [*, N_res, H, C_hidden]
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
q
=
q
.
view
(
q
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H, 2 * C_hidden]
kv
=
kv
.
view
(
kv
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
# [*, N_res, H, P_qk]
# [*, N_res, H, P_qk]
q_pts
=
self
.
linear_q_points
(
s
,
r
)
q_pts
=
self
.
linear_q_points
(
s
,
r
)
# [*, N_res, H, P_qk, 3]
# The following two blocks are equivalent
k_pts
=
self
.
linear_k_points
(
s
,
r
)
# They're separated only to preserve compatibility with old AF weights
if
(
self
.
is_multimer
):
# [*, N_res, H * C_hidden]
k
=
self
.
linear_k
(
s
)
v
=
self
.
linear_v
(
s
)
# [*, N_res, H, C_hidden]
k
=
k
.
view
(
k
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
v
=
v
.
view
(
v
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H, P_qk, 3]
k_pts
=
self
.
linear_k_points
(
s
,
r
)
# [*, N_res, H, P_v, 3]
v_pts
=
self
.
linear_v_points
(
s
,
r
)
else
:
# [*, N_res, H * 2 * C_hidden]
kv
=
self
.
linear_kv
(
s
)
# [*, N_res, H, P_v, 3]
# [*, N_res, H, 2 * C_hidden]
v_pts
=
self
.
linear_v_points
(
s
,
r
)
kv
=
kv
.
view
(
kv
.
shape
[:
-
1
]
+
(
self
.
no_heads
,
-
1
))
# [*, N_res, H, C_hidden]
k
,
v
=
torch
.
split
(
kv
,
self
.
c_hidden
,
dim
=-
1
)
kv_pts
=
self
.
linear_kv_points
(
s
,
r
)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts
=
kv_pts
.
view
(
kv_pts
.
shape
[:
-
2
]
+
(
self
.
no_heads
,
-
1
,
3
))
# [*, N_res, H, P_q/P_v, 3]
k_pts
,
v_pts
=
torch
.
split
(
kv_pts
,
[
self
.
no_qk_points
,
self
.
no_v_points
],
dim
=-
2
)
##########################
##########################
# Compute attention scores
# Compute attention scores
...
@@ -324,12 +363,14 @@ class InvariantPointAttention(nn.Module):
...
@@ -324,12 +363,14 @@ class InvariantPointAttention(nn.Module):
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
*=
math
.
sqrt
(
1.0
/
(
3
*
self
.
c_hidden
))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
a
+=
(
math
.
sqrt
(
1.0
/
3
)
*
permute_final_dims
(
b
,
(
2
,
0
,
1
)))
for
c
in
q_pts
:
print
(
type
(
c
))
# [*, N_res, N_res, H, P_q, 3]
# [*, N_res, N_res, H, P_q, 3]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
q_pts
[...,
None
,
:,
:]
-
k_pts
[...,
None
,
:,
:,
:]
pt_att
=
pt_att
*
pt_att
+
self
.
eps
# [*, N_res, N_res, H, P_q]
# [*, N_res, N_res, H, P_q]
pt_att
=
sum
(
torch
.
unbind
(
pt_att
,
dim
=-
1
))
pt_att
=
sum
(
[
c
**
2
for
c
in
pt_att
])
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
head_weights
=
self
.
softplus
(
self
.
head_weights
).
view
(
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
*
((
1
,)
*
len
(
pt_att
.
shape
[:
-
2
])
+
(
-
1
,
1
))
)
)
...
@@ -364,9 +405,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -364,9 +405,7 @@ class InvariantPointAttention(nn.Module):
# As DeepMind explains, this manual matmul ensures that the operation
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# happens in float32.
# [*, N_res, H, P_v]
# [*, N_res, H, P_v]
o_pt
=
v_pts
.
tensor_dot
(
o_pt
=
v_pts
*
permute_final_dims
(
a
,
(
1
,
2
,
0
)).
unsqueeze
(
-
1
)
permute_final_dims
(
a
,
(
1
,
2
,
0
)).
unsqueeze
(
-
1
)
)
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
o_pt
=
o_pt
.
sum
(
dim
=-
3
)
# [*, N_res, H, P_v]
# [*, N_res, H, P_v]
...
@@ -493,6 +532,7 @@ class StructureModule(nn.Module):
...
@@ -493,6 +532,7 @@ class StructureModule(nn.Module):
trans_scale_factor
,
trans_scale_factor
,
epsilon
,
epsilon
,
inf
,
inf
,
is_multimer
=
False
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -546,6 +586,7 @@ class StructureModule(nn.Module):
...
@@ -546,6 +586,7 @@ class StructureModule(nn.Module):
self
.
trans_scale_factor
=
trans_scale_factor
self
.
trans_scale_factor
=
trans_scale_factor
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
self
.
inf
=
inf
self
.
inf
=
inf
self
.
is_multimer
=
is_multimer
# To be lazily initialized later
# To be lazily initialized later
self
.
default_frames
=
None
self
.
default_frames
=
None
...
@@ -567,6 +608,7 @@ class StructureModule(nn.Module):
...
@@ -567,6 +608,7 @@ class StructureModule(nn.Module):
self
.
no_v_points
,
self
.
no_v_points
,
inf
=
self
.
inf
,
inf
=
self
.
inf
,
eps
=
self
.
epsilon
,
eps
=
self
.
epsilon
,
is_multimer
=
self
.
is_multimer
,
)
)
self
.
ipa_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
self
.
ipa_dropout
=
nn
.
Dropout
(
self
.
dropout_rate
)
...
@@ -587,27 +629,62 @@ class StructureModule(nn.Module):
...
@@ -587,27 +629,62 @@ class StructureModule(nn.Module):
self
.
no_angles
,
self
.
no_angles
,
self
.
epsilon
,
self
.
epsilon
,
)
)
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
if
self
.
default_frames
is
None
:
self
.
default_frames
=
torch
.
tensor
(
restype_rigid_group_default_frame
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
group_idx
is
None
:
self
.
group_idx
=
torch
.
tensor
(
restype_atom14_to_rigid_group
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
atom_mask
is
None
:
self
.
atom_mask
=
torch
.
tensor
(
restype_atom14_mask
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
if
self
.
lit_positions
is
None
:
self
.
lit_positions
=
torch
.
tensor
(
restype_atom14_rigid_group_positions
,
dtype
=
float_dtype
,
device
=
device
,
requires_grad
=
False
,
)
def
torsion_angles_to_frames
(
self
,
r
,
alpha
,
f
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
alpha
.
dtype
,
alpha
.
device
)
# Separated purely to make testing less annoying
return
torsion_angles_to_frames
(
r
,
alpha
,
f
,
self
.
default_frames
)
def
frames_and_literature_positions_to_atom14_pos
(
self
,
r
,
f
# [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self
.
_init_residue_constants
(
r
.
get_rots
().
dtype
,
r
.
get_rots
().
device
)
return
frames_and_literature_positions_to_atom14_pos
(
r
,
f
,
self
.
default_frames
,
self
.
group_idx
,
self
.
atom_mask
,
self
.
lit_positions
,
)
def
forward
(
def
_forward_monomer
(
self
,
self
,
s
,
s
,
z
,
z
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
):
):
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
if
mask
is
None
:
if
mask
is
None
:
# [*, N]
# [*, N]
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
mask
=
s
.
new_ones
(
s
.
shape
[:
-
1
])
...
@@ -690,51 +767,97 @@ class StructureModule(nn.Module):
...
@@ -690,51 +767,97 @@ class StructureModule(nn.Module):
return
outputs
return
outputs
def
_init_residue_constants
(
self
,
float_dtype
,
device
):
def _forward_multimer(
    self,
    s,
    z,
    aatype,
    mask=None,
):
    """Multimer variant of the structure-module forward pass.

    Iteratively refines backbone rigid frames (Rigid3Array) with IPA,
    producing per-block predictions that are stacked at the end.
    NOTE(review): shape comments below follow the monomer convention
    ([*, N, C_s] etc.) — confirm against the multimer callers.
    """
    if mask is None:
        # [*, N] — default to an all-ones sequence mask.
        mask = s.new_ones(s.shape[:-1])

    # [*, N, C_s]
    s = self.layer_norm_s(s)

    # [*, N, N, C_z]
    z = self.layer_norm_z(z)

    # [*, N, C_s] — keep the pre-projection single repr for the angle resnet.
    s_initial = s
    s = self.linear_in(s)

    # [*, N] identity rigid frames to start refinement from.
    rigids = Rigid3Array.identity(
        s.shape[:-1],
        s.device,
    )
    outputs = []
    for i in range(self.no_blocks):
        # [*, N, C_s]
        s = s + self.ipa(s, z, rigids, mask)
        s = self.ipa_dropout(s)
        s = self.layer_norm_ipa(s)
        s = self.transition(s)

        # [*, N] — compose the predicted backbone update into the frames.
        rigids = rigids @ self.bb_update(s)

        # [*, N, 7, 2]
        unnormalized_angles, angles = self.angle_resnet(s, s_initial)

        all_frames_to_global = self.torsion_angles_to_frames(
            rigids.scale_translation(self.trans_scale_factor),
            angles,
            aatype,
        )

        pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
            all_frames_to_global,
            aatype,
        )

        preds = {
            "frames": rigids.scale_translation(
                self.trans_scale_factor
            ).to_tensor7(),
            "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
            "unnormalized_angles": unnormalized_angles,
            "angles": angles,
            "positions": pred_xyz,
        }

        outputs.append(preds)

        if i < (self.no_blocks - 1):
            # Stop rotation gradients between blocks (training stability).
            rigids = rigids.stop_rot_gradient()

    # Stack the per-block prediction dicts along a new leading dimension.
    outputs = dict_multimap(torch.stack, outputs)
    outputs["single"] = s

    return outputs
def forward(
    self,
    s,
    z,
    aatype,
    mask=None,
):
    """
    Args:
        s:
            [*, N_res, C_s] single representation
        z:
            [*, N_res, N_res, C_z] pair representation
        aatype:
            [*, N_res] amino acid indices
        mask:
            Optional [*, N_res] sequence mask
    Returns:
        A dictionary of outputs
    """
    # Dispatch to the multimer or monomer implementation.
    run = self._forward_multimer if self.is_multimer else self._forward_monomer
    return run(s, z, aatype, mask)
openfold/np/protein.py
View file @
4bd1b4d5
...
@@ -62,7 +62,7 @@ class Protein:
...
@@ -62,7 +62,7 @@ class Protein:
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
b_factors
:
np
.
ndarray
# [num_res, num_atom_type]
def __post_init__(self):
    # Reject structures with more distinct chains than can be written to
    # the PDB format (see the error message below).
    if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
        raise ValueError(
            f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
            "chains because these cannot be written to PDB format"
        )
...
...
openfold/utils/all_atom_multimer.py
0 → 100644
View file @
4bd1b4d5
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from
functools
import
partial
from
typing
import
Dict
,
Text
,
Tuple
import
torch
from
openfold.np
import
residue_constants
as
rc
from
openfold.utils
import
geometry
,
tensor_utils
import
numpy
as
np
def squared_difference(x, y):
    """Return the elementwise squared difference (x - y) ** 2.

    Bug fix: the original called ``jnp.square``, but this module imports
    torch, not jax — ``jnp`` is undefined and the call would raise
    NameError.

    Args:
        x: A tensor.
        y: A tensor broadcastable against x.

    Returns:
        A tensor of the same broadcast shape as the inputs.
    """
    return torch.square(x - y)
def get_rc_tensor(rc_np, aatype):
    """Index a residue-constant table by residue type.

    Converts the (numpy) table to a tensor on aatype's device and gathers
    one row per entry of aatype.
    """
    table = torch.tensor(rc_np, device=aatype.device)
    return table[aatype]
def atom14_to_atom37(
    atom14_data: torch.Tensor,  # (*, N, 14, ...)
    aatype: torch.Tensor  # (*, N)
) -> torch.Tensor:  # (*, N, 37, ...)
    """Convert atom14 to atom37 representation.

    Gathers atom14 entries into atom37 slots per residue type and zeroes
    slots that do not exist for the residue.

    Raises:
        ValueError: if atom14_data is neither (*, N, 14) nor (*, N, 14, C).
    """
    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
    no_batch_dims = len(aatype.shape) - 1
    atom37_data = tensor_utils.batched_gather(
        atom14_data,
        idx_atom37_to_atom14,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1
    )
    atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
    if len(atom14_data.shape) == no_batch_dims + 2:
        # Per-atom data: mask directly. Cast the mask to the data dtype —
        # the original in-place ``*=`` fails when the dtypes differ.
        atom37_data = atom37_data * atom37_mask.to(atom37_data.dtype)
    elif len(atom14_data.shape) == no_batch_dims + 3:
        # Per-atom channel data: broadcast the mask over the channel dim.
        # Bug fix: torch.Tensor has no ``.astype`` (that is the numpy API);
        # use ``.to`` instead.
        atom37_data = atom37_data * atom37_mask[..., None].to(atom37_data.dtype)
    else:
        raise ValueError("Incorrectly shaped data")
    return atom37_data
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
    """Convert Atom37 positions to Atom14 positions.

    Returns:
        (atom14_positions, atom14_mask), where the mask is 1 only for
        atoms that exist for the residue type AND were present in the
        atom37 input mask.
    """
    residx_atom14_to_atom37 = get_rc_tensor(rc.RESTYPE_ATOM14_TO_ATOM37, aatype)
    # Bug fix: the original used len(aatype.shape) without the -1 used by
    # the sibling atom14_to_atom37, which made the gather dim point one
    # past the atom dimension.
    no_batch_dims = len(aatype.shape) - 1
    atom14_mask = tensor_utils.batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    ).to(torch.float32)
    # create a mask for known groundtruth positions
    atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
    # gather the groundtruth positions
    # Bug fix: the original had a stray trailing comma after this call,
    # binding atom14_positions to a 1-tuple instead of a tensor.
    atom14_positions = tensor_utils.batched_gather(
        all_atom_pos,
        residx_atom14_to_atom37,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    atom14_positions = atom14_mask * atom14_positions
    return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: torch.Tensor, mask):
    """Get alternative atom14 positions.

    Applies the per-residue renaming transform to produce the positions
    and mask of the alternative (swapped) naming of ambiguous atoms.
    """
    # Per-residue renaming transform, shape (num_res, 14, 14).
    renaming_matrices = get_rc_tensor(rc.RENAMING_MATRICES, aatype)
    alt_positions = torch.sum(
        positions[..., None, :] * renaming_matrices[..., None], dim=-2
    )
    # The alternative mask differs from the input mask when only one atom
    # of an ambiguous pair has a ground-truth position.
    alt_mask = torch.sum(mask[..., None] * renaming_matrices, dim=-2)
    return alt_positions, alt_mask
def atom37_to_frames(
    aatype: torch.Tensor,  # (...)
    all_atom_positions: torch.Tensor,  # (..., 37)
    all_atom_mask: torch.Tensor,  # (..., 37)
) -> Dict[Text, torch.Tensor]:
    """Computes the frames for the up to 8 rigid groups for each residue.

    Fixes relative to the original:
      * ``batched_gather`` was called with a ``batch_dims=`` keyword (and
        once without ``dim=``) while the sibling functions in this module
        call it with ``dim=``/``no_batch_dims=``; made the calls consistent.
      * ``torch.min(x, dim=-1)`` returns a (values, indices) namedtuple;
        the original multiplied the namedtuple by a tensor, which raises.
      * ``torch.Tensor(t, device=...)`` was applied to an already-built
        tensor; that constructor takes no device kwarg.
    """
    # 0: 'backbone group',
    # 1: 'pre-omega-group', (empty)
    # 2: 'phi-group', (currently empty, because it defines only hydrogens)
    # 3: 'psi-group',
    # 4,5,6,7: 'chi1,2,3,4-group'
    no_batch_dims = len(aatype.shape) - 1
    # Compute the gather indices for all residues in the chain.
    # shape (N, 8, 3)
    residx_rigidgroup_base_atom37_idx = get_rc_tensor(
        rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype
    )
    # Gather the base atom positions for each rigid group.
    base_atom_pos = tensor_utils.batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    # Compute the Rigids.
    point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
    origin = base_atom_pos[..., :, :, 1]
    point_on_xy_plane = base_atom_pos[..., :, :, 2]
    gt_rotation = geometry.Rot3Array.from_two_vectors(
        origin - point_on_neg_x_axis, point_on_xy_plane - origin
    )
    gt_frames = geometry.Rigid3Array(gt_rotation, origin)
    # Compute a mask whether the group exists.
    # (N, 8)
    group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)
    # Compute a mask whether ground truth exists for the group.
    # shape (N, 8, 3)
    gt_atoms_exist = tensor_utils.batched_gather(
        all_atom_mask.to(dtype=torch.float32),
        residx_rigidgroup_base_atom37_idx,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )
    # A group's ground truth exists only if all three base atoms exist.
    gt_exists = torch.min(gt_atoms_exist, dim=-1).values * group_exists  # (N, 8)
    # Adapt backbone frame to old convention (mirror x-axis and z-axis).
    rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
    rots[0, 0, 0] = -1
    rots[0, 2, 2] = -1
    gt_frames = gt_frames.compose_rotation(
        geometry.Rot3Array.from_array(torch.tensor(rots, device=aatype.device))
    )
    # The frames for ambiguous rigid groups are just rotated by 180 degree
    # around the x-axis. The ambiguous group is always the last chi-group.
    restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
    restype_rigidgroup_rots = np.tile(
        np.eye(3, dtype=np.float32), [21, 8, 1, 1]
    )
    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
    # Gather the ambiguity information for each residue.
    residx_rigidgroup_is_ambiguous = torch.tensor(
        restype_rigidgroup_is_ambiguous,
        device=aatype.device,
    )[aatype]
    ambiguity_rot = torch.tensor(
        restype_rigidgroup_rots,
        device=aatype.device,
    )[aatype]
    ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)
    # Create the alternative ground truth frames.
    alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
    # NOTE(review): this reshape assumes the group axes were flattened
    # upstream — confirm it is a no-op for the shapes produced here.
    fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,))
    # reshape back to original residue layout
    gt_frames = fix_shape(gt_frames)
    gt_exists = fix_shape(gt_exists)
    group_exists = fix_shape(group_exists)
    residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
    alt_gt_frames = fix_shape(alt_gt_frames)
    return {
        'rigidgroups_gt_frames': gt_frames,  # Rigid (..., 8)
        'rigidgroups_gt_exists': gt_exists,  # (..., 8)
        'rigidgroups_group_exists': group_exists,  # (..., 8)
        'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,  # (..., 8)
        'rigidgroups_alt_gt_frames': alt_gt_frames,  # Rigid (..., 8)
    }
def torsion_angles_to_frames(
    aatype: torch.Tensor,  # (N)
    backb_to_global: geometry.Rigid3Array,  # (N)
    torsion_angles_sin_cos: torch.Tensor  # (N, 7, 2)
) -> geometry.Rigid3Array:  # (N, 8)
    """Compute rigid group frames from torsion angles.

    Fixes relative to the original:
      * ``.unsqueeze()`` was called with no dim argument (a TypeError);
        the backbone padding is now taken as a slice of the angle tensors,
        which also keeps the float dtype (``zeros_like(aatype)`` would have
        been an integer tensor and ``torch.cat`` rejects mixed dtypes).
      * ``Rigid3Array.cat`` was referenced bare, but only the ``geometry``
        module is imported here.
      * Removed the unused ``num_residues`` local.
    """
    # Gather the default frames for all rigid groups.
    # geometry.Rigid3Array with shape (N, 8)
    m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
    default_frames = geometry.Rigid3Array.from_array4x4(m)
    # Create the rotation matrices according to the given angles (each frame
    # is defined such that its rotation is around the x-axis).
    sin_angles = torsion_angles_sin_cos[..., 0]
    cos_angles = torsion_angles_sin_cos[..., 1]
    # Insert zero rotation (sin=0, cos=1) for the backbone group.
    sin_angles = torch.cat(
        [torch.zeros_like(sin_angles[..., :1]), sin_angles], dim=-1
    )
    cos_angles = torch.cat(
        [torch.ones_like(cos_angles[..., :1]), cos_angles], dim=-1
    )
    zeros = torch.zeros_like(sin_angles)
    ones = torch.ones_like(sin_angles)
    # all_rots are geometry.Rot3Array with shape (..., N, 8)
    all_rots = geometry.Rot3Array(
        ones, zeros, zeros,
        zeros, cos_angles, -sin_angles,
        zeros, sin_angles, cos_angles
    )
    # Apply rotations to the frames.
    all_frames = default_frames.compose_rotation(all_rots)
    # chi2, chi3, and chi4 frames do not transform to the backbone frame but
    # to the previous frame. So chain them up accordingly.
    chi1_frame_to_backb = all_frames[..., 4]
    chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
    chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
    chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]
    all_frames_to_backb = geometry.Rigid3Array.cat(
        [
            all_frames[..., 0:5],
            chi2_frame_to_backb[..., None],
            chi3_frame_to_backb[..., None],
            chi4_frame_to_backb[..., None]
        ],
        dim=-1
    )
    # Create the global frames.
    # shape (N, 8)
    all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb
    return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
    aatype: torch.Tensor,  # (*, N)
    all_frames_to_global: geometry.Rigid3Array  # (N, 8)
) -> geometry.Vec3Array:  # (*, N, 14)
    """Put atom literature positions (atom14 encoding) in each rigid group."""
    # Pick the appropriate transform for every atom.
    residx_to_group_idx = get_rc_tensor(
        rc.restype_atom14_to_rigid_group, aatype
    )
    # One-hot over the 8 rigid groups: shape (*, N, 14, 8).
    group_mask = torch.nn.functional.one_hot(residx_to_group_idx, num_classes=8)
    # Select each atom's frame by masking all 8 frames and summing over the
    # group axis. geometry.Rigid3Array with shape (N, 14).
    map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask
    map_atoms_to_global = map_atoms_to_global.map_tensor_fn(
        partial(torch.sum, dim=-1)
    )
    # Gather the literature atom positions for each residue.
    # geometry.Vec3Array with shape (N, 14)
    lit_positions = geometry.Vec3Array.from_array(
        get_rc_tensor(rc.restype_atom14_rigid_group_positions, aatype)
    )
    # Transform each atom from its local frame to the global frame.
    # geometry.Vec3Array with shape (N, 14)
    pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
    # Mask out non-existing atoms.
    mask = get_rc_tensor(rc.restype_atom14_mask, aatype)
    pred_positions = pred_positions * mask
    return pred_positions
def extreme_ca_ca_distance_violations(
    positions: geometry.Vec3Array,  # (N, 37(14))
    mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps: float = 1e-6
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbor.

    Bug fix: the original called ``.astype(torch.float32)`` on a torch
    tensor; tensors expose ``.to`` (``astype`` is the numpy API).
    """
    this_ca_pos = positions[..., :-1, 1]  # (N - 1,)
    this_ca_mask = mask[..., :-1, 1]  # (N - 1)
    next_ca_pos = positions[..., 1:, 1]  # (N - 1,)
    next_ca_mask = mask[..., 1:, 1]  # (N - 1)
    # Only count consecutive residues (no chain breaks / index gaps).
    has_no_gap_mask = (
        (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
    ).to(torch.float32)
    ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
    violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1)
def get_chi_atom_indices(device: torch.device):
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types
      are in the order specified in rc.restypes + unknown residue type at the
      end. For chi angles which are not defined on the residue, the positions
      indices are by default set to 0.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        three_letter = rc.restype_1to3[one_letter]
        per_chi = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in rc.chi_angles_atoms[three_letter]
        ]
        # Pad residues that define fewer than 4 chi angles.
        per_chi.extend([[0, 0, 0, 0]] * (4 - len(per_chi)))
        chi_atom_indices.append(per_chi)
    # UNKNOWN residue gets all-zero indices.
    chi_atom_indices.append([[0, 0, 0, 0]] * 4)
    return torch.tensor(chi_atom_indices, device=device)
def compute_chi_angles(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor
):
    """Computes the chi angles given all atom positions and the amino acid type.

    Args:
      positions: A Vec3Array of shape [num_res, rc.atom_type_num], with
        positions of atoms needed to calculate chi angles. Supports up to 1
        batch dimension.
      mask: An optional tensor of shape [num_res, rc.atom_type_num] that masks
        which atom positions are set for each residue. If given, then the chi
        mask will be set to 1 for a chi angle only if the amino acid has that
        chi angle and all the chi atoms needed to calculate that chi angle are
        set. If not given (set to None), the chi mask will be set to 1 for a
        chi angle if the amino acid has that chi angle and whether the actual
        atoms needed to calculate it were set will be ignored.
      aatype: A tensor of shape [num_res] with amino acid type integer
        code (0 to 21). Supports up to 1 batch dimension.

    Returns:
      A tuple of tensors (chi_angles, mask), where both have shape
      [num_res, 4]. The mask masks out unused chi angles for amino acid
      types that have less than 4 chi angles. If atom_positions_mask is set,
      the chi mask will also mask out uncomputable chi angles.
    """
    # Don't assert on the num_res and batch dimensions as they might be
    # unknown.
    assert positions.shape[-1] == rc.atom_type_num
    assert mask.shape[-1] == rc.atom_type_num
    no_batch_dims = len(aatype.shape) - 1
    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
    chi_atom_indices = get_chi_atom_indices(aatype.device)
    # DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why
    # theirs works.
    aatype_gapless = torch.clamp(aatype, max=20)
    # Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4].
    atom_indices = chi_atom_indices[aatype_gapless]
    # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
    chi_angle_atoms = positions.map_tensor_fn(
        partial(
            tensor_utils.batched_gather,
            inds=atom_indices,
            dim=-1,
            no_batch_dims=no_batch_dims + 1
        )
    )
    # Split the four gathered atoms per chi and compute the dihedral.
    a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
    chi_angles = geometry.dihedral_angle(a, b, c, d)
    # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device)
    # Compute the chi angle mask. Shape [num_res, chis=4].
    chi_mask = chi_angles_mask[aatype_gapless]
    # The chi_mask is set to 1 only when all necessary chi angle atoms were
    # set. Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
    chi_angle_atoms_mask = tensor_utils.batched_gather(
        mask, atom_indices, dim=-1, no_batch_dims=no_batch_dims + 1
    )
    # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
    chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32)
    return chi_angles, chi_mask
def make_transform_from_reference(
    a_xyz: geometry.Vec3Array,
    b_xyz: geometry.Vec3Array,
    c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
    """Returns rotation and translation matrices to convert from reference.

    Note that this method does not take care of symmetries. If you provide
    the coordinates in the non-standard way, the A atom will end up in the
    negative y-axis rather than in the positive y-axis. You need to take
    care of such cases in your code.

    Args:
      a_xyz: A Vec3Array.
      b_xyz: A Vec3Array.
      c_xyz: A Vec3Array.

    Returns:
      A Rigid3Array which, when applied to coordinates in a canonicalized
      reference frame, will give coordinates approximately equal the
      original coordinates (in the global frame).
    """
    # Anchor the frame at B; the first vector (toward C) fixes the x-axis,
    # the second (toward A) fixes the xy-plane.
    to_c = c_xyz - b_xyz
    to_a = a_xyz - b_xyz
    rotation = geometry.Rot3Array.from_two_vectors(to_c, to_a)
    return geometry.Rigid3Array(rotation, b_xyz)
def make_backbone_affine(
    positions: geometry.Vec3Array,
    mask: torch.Tensor,
    aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
    """Build per-residue backbone frames from the N, CA, C atoms.

    Returns the frames together with a mask that is nonzero only where
    all three backbone atoms are present.
    """
    n_idx = rc.atom_order['N']
    ca_idx = rc.atom_order['CA']
    c_idx = rc.atom_order['C']
    rigid_mask = mask[..., n_idx] * mask[..., ca_idx] * mask[..., c_idx]
    rigid = make_transform_from_reference(
        a_xyz=positions[..., n_idx],
        b_xyz=positions[..., ca_idx],
        c_xyz=positions[..., c_idx],
    )
    return rigid, rigid_mask
openfold/utils/argparse_utils.py
0 → 100644
View file @
4bd1b4d5
from
argparse
import
HelpFormatter
from
operator
import
attrgetter
class ArgparseAlphabetizer(HelpFormatter):
    """
    HelpFormatter that lists a parser's optional arguments in
    alphabetical order of their option strings.
    """

    @staticmethod
    def sort_actions(actions):
        # Order actions by their option strings (e.g. "--alpha" < "--zeta").
        return sorted(actions, key=attrgetter("option_strings"))

    # Formats the help message
    def add_arguments(self, actions):
        super().add_arguments(self.sort_actions(actions))

    # Formats the usage message
    def add_usage(self, usage, actions, groups, prefix=None):
        super().add_usage(usage, self.sort_actions(actions), groups, prefix)
def remove_arguments(parser, args):
    """Strip the given option strings (e.g. "--foo") from an argparse parser.

    Args:
        parser: An argparse.ArgumentParser instance (uses private argparse
            internals: _actions and _handle_conflict_resolve).
        args: Iterable of option strings to remove.

    Bug fix: the original iterated parser._actions directly while
    _handle_conflict_resolve removes entries from it, which can skip
    actions; iterate over a snapshot instead. Also replaced
    vars(action)["option_strings"] with plain attribute access.
    """
    for opt in args:
        # Snapshot: _handle_conflict_resolve mutates parser._actions.
        for action in list(parser._actions):
            if opt in action.option_strings:
                parser._handle_conflict_resolve(None, [(opt, action)])
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment