OpenDAS / OpenFold · Commits · 07e64267

Commit 07e64267, authored Oct 16, 2021 by Gustaf Ahdritz

Standardize code style
parent de07730f

Showing 20 changed files with 2923 additions and 2530 deletions (+2923, -2530)
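The commit applies one set of formatting conventions across the data pipeline: double-quoted strings, trailing commas in multi-line calls, dedented closing brackets, and Black-style wrapping of long signatures. A minimal before/after sketch of the convention, using a fragment that appears in the diff below:

# Before (single quotes, closing bracket on the last argument line):
features['aatype'] = residue_constants.sequence_to_onehot(
    sequence=sequence,
    mapping=residue_constants.restype_order_with_x,
    map_unknown_to_x=True)

# After (double quotes, trailing comma, closing bracket on its own line):
features["aatype"] = residue_constants.sequence_to_onehot(
    sequence=sequence,
    mapping=residue_constants.restype_order_with_x,
    map_unknown_to_x=True,
)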
openfold/config.py                  +374  -360
openfold/data/data_pipeline.py      +144  -135
openfold/data/data_transforms.py    +442  -374
openfold/data/feature_pipeline.py    +31   -30
openfold/data/input_pipeline.py      +59   -49
openfold/data/mmcif_parsing.py      +360  -320
openfold/data/parsers.py             +81   -57
openfold/data/templates.py          +356  -216
openfold/data/tools/hhblits.py      +146  -125
openfold/data/tools/hhsearch.py      +72   -59
openfold/data/tools/jackhmmer.py    +200  -169
openfold/data/tools/kalign.py        +87   -76
openfold/data/tools/utils.py         +12   -12
openfold/model/__init__.py            +6    -5
openfold/model/dropout.py            +27   -24
openfold/model/embedders.py         +134  -133
openfold/model/evoformer.py         +125  -114
openfold/model/heads.py              +66   -61
openfold/model/model.py             +117  -115
openfold/model/msa.py                +84   -96
openfold/config.py
(This diff is collapsed.)
openfold/data/data_pipeline.py

...
@@ -27,45 +27,45 @@ from openfold.np import residue_constants
 FeatureDict = Mapping[str, np.ndarray]


 def make_sequence_features(
     sequence: str, description: str, num_res: int
 ) -> FeatureDict:
     """Construct a feature dict of sequence features."""
     features = {}
-    features['aatype'] = residue_constants.sequence_to_onehot(
+    features["aatype"] = residue_constants.sequence_to_onehot(
         sequence=sequence,
         mapping=residue_constants.restype_order_with_x,
-        map_unknown_to_x=True
+        map_unknown_to_x=True,
     )
-    features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
-    features['domain_name'] = np.array(
-        [description.encode('utf-8')], dtype=np.object_)
-    features['residue_index'] = np.array(range(num_res), dtype=np.int32)
-    features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
-    features['sequence'] = np.array(
-        [sequence.encode('utf-8')], dtype=np.object_)
+    features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
+    features["domain_name"] = np.array(
+        [description.encode("utf-8")], dtype=np.object_
+    )
+    features["residue_index"] = np.array(range(num_res), dtype=np.int32)
+    features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
+    features["sequence"] = np.array(
+        [sequence.encode("utf-8")], dtype=np.object_
+    )
     return features


 def make_mmcif_features(
     mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
 ) -> FeatureDict:
     input_sequence = mmcif_object.chain_to_seqres[chain_id]
-    description = '_'.join([mmcif_object.file_id, chain_id])
+    description = "_".join([mmcif_object.file_id, chain_id])
     num_res = len(input_sequence)

     mmcif_feats = {}

-    mmcif_feats.update(make_sequence_features(
-        sequence=input_sequence,
-        description=description,
-        num_res=num_res,
-    ))
+    mmcif_feats.update(
+        make_sequence_features(
+            sequence=input_sequence,
+            description=description,
+            num_res=num_res,
+        )
+    )

     all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
         mmcif_object=mmcif_object, chain_id=chain_id
...
@@ -78,7 +78,7 @@ def make_mmcif_features(
     )
     mmcif_feats["release_date"] = np.array(
-        [mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_
+        [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
     )

     return mmcif_feats
...
@@ -86,17 +86,20 @@ def make_mmcif_features(
 def make_msa_features(
     msas: Sequence[Sequence[str]],
-    deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
+    deletion_matrices: Sequence[parsers.DeletionMatrix],
+) -> FeatureDict:
     """Constructs a feature dict of MSA features."""
     if not msas:
-        raise ValueError('At least one MSA must be provided.')
+        raise ValueError("At least one MSA must be provided.")

     int_msa = []
     deletion_matrix = []
     seen_sequences = set()
     for msa_index, msa in enumerate(msas):
         if not msa:
             raise ValueError(
-                f'MSA {msa_index} must contain at least one sequence.'
+                f"MSA {msa_index} must contain at least one sequence."
             )
         for sequence_index, sequence in enumerate(msa):
             if sequence in seen_sequences:
                 continue
...
@@ -109,17 +112,19 @@ def make_msa_features(
     num_res = len(msas[0][0])
     num_alignments = len(int_msa)
     features = {}
-    features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
-    features['msa'] = np.array(int_msa, dtype=np.int32)
-    features['num_alignments'] = np.array(
+    features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
+    features["msa"] = np.array(int_msa, dtype=np.int32)
+    features["num_alignments"] = np.array(
         [num_alignments] * num_res, dtype=np.int32
     )
     return features


 class AlignmentRunner:
-    """ Runs alignment tools and saves the results """
-    def __init__(self,
+    """Runs alignment tools and saves the results"""
+
+    def __init__(
+        self,
         jackhmmer_binary_path: str,
         hhblits_binary_path: str,
         hhsearch_binary_path: str,
...
@@ -161,105 +166,109 @@ class AlignmentRunner:
         )
         self.hhsearch_pdb70_runner = hhsearch.HHSearch(
             binary_path=hhsearch_binary_path,
             databases=[pdb70_database_path]
         )
         self.uniref_max_hits = uniref_max_hits
         self.mgnify_max_hits = mgnify_max_hits

-    def run(self,
+    def run(
+        self,
         fasta_path: str,
         output_dir: str,
     ):
         """Runs alignment tools on a sequence"""
         jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
             fasta_path
         )[0]
         uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
-            jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
+            jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
         )
-        uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
-        with open(uniref90_out_path, 'w') as f:
+        uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
+        with open(uniref90_out_path, "w") as f:
             f.write(uniref90_msa_as_a3m)

         jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
             fasta_path
         )[0]
         mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
-            jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
+            jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
         )
-        mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
-        with open(mgnify_out_path, 'w') as f:
+        mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
+        with open(mgnify_out_path, "w") as f:
             f.write(mgnify_msa_as_a3m)

         hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
-        pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
-        with open(pdb70_out_path, 'w') as f:
+        pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
+        with open(pdb70_out_path, "w") as f:
             f.write(hhsearch_result)

         if self._use_small_bfd:
             jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
                 fasta_path
             )[0]
-            bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
-            with open(bfd_out_path, 'w') as f:
-                f.write(jackhmmer_small_bfd_result['sto'])
+            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
+            with open(bfd_out_path, "w") as f:
+                f.write(jackhmmer_small_bfd_result["sto"])
         else:
-            hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
-            if (output_dir is not None):
-                bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
-                with open(bfd_out_path, 'w') as f:
-                    f.write(hhblits_bfd_uniclust_result['a3m'])
+            hhblits_bfd_uniclust_result = (
+                self.hhblits_bfd_uniclust_runner.query(fasta_path)
+            )
+            if output_dir is not None:
+                bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
+                with open(bfd_out_path, "w") as f:
+                    f.write(hhblits_bfd_uniclust_result["a3m"])


 class DataPipeline:
     """Assembles input features."""
-    def __init__(self,
+
+    def __init__(
+        self,
         template_featurizer: templates.TemplateHitFeaturizer,
         use_small_bfd: bool,
     ):
         self.template_featurizer = template_featurizer
         self.use_small_bfd = use_small_bfd

-    def _parse_alignment_output(self,
+    def _parse_alignment_output(
+        self,
         alignment_dir: str,
     ) -> Mapping[str, Any]:
-        uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
-        with open(uniref90_out_path, 'r') as f:
-            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
-                f.read()
-            )
+        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
+        with open(uniref90_out_path, "r") as f:
+            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())

-        mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
-        with open(mgnify_out_path, 'r') as f:
-            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
-                f.read()
-            )
+        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
+        with open(mgnify_out_path, "r") as f:
+            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())

-        pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
-        with open(pdb70_out_path, 'r') as f:
-            hhsearch_hits = parsers.parse_hhr(
-                f.read()
-            )
+        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
+        with open(pdb70_out_path, "r") as f:
+            hhsearch_hits = parsers.parse_hhr(f.read())

-        if (self.use_small_bfd):
-            bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
-            with open(bfd_out_path, 'r') as f:
+        if self.use_small_bfd:
+            bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
+            with open(bfd_out_path, "r") as f:
                 bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
                     f.read()
                 )
         else:
-            bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
-            with open(bfd_out_path, 'r') as f:
-                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
-                    f.read()
-                )
+            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
+            with open(bfd_out_path, "r") as f:
+                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())

         return {
-            'uniref90_msa': uniref90_msa,
-            'uniref90_deletion_matrix': uniref90_deletion_matrix,
-            'mgnify_msa': mgnify_msa,
-            'mgnify_deletion_matrix': mgnify_deletion_matrix,
-            'hhsearch_hits': hhsearch_hits,
-            'bfd_msa': bfd_msa,
-            'bfd_deletion_matrix': bfd_deletion_matrix,
+            "uniref90_msa": uniref90_msa,
+            "uniref90_deletion_matrix": uniref90_deletion_matrix,
+            "mgnify_msa": mgnify_msa,
+            "mgnify_deletion_matrix": mgnify_deletion_matrix,
+            "hhsearch_hits": hhsearch_hits,
+            "bfd_msa": bfd_msa,
+            "bfd_deletion_matrix": bfd_deletion_matrix,
         }

-    def process_fasta(self,
+    def process_fasta(
+        self,
         fasta_path: str,
         alignment_dir: str,
     ) -> FeatureDict:
...
@@ -269,7 +278,8 @@ class DataPipeline:
         input_seqs, input_descs = parsers.parse_fasta(fasta_str)
         if len(input_seqs) != 1:
             raise ValueError(
-                f'More than one input sequence found in {fasta_path}.')
+                f"More than one input sequence found in {fasta_path}."
+            )
         input_sequence = input_seqs[0]
         input_description = input_descs[0]
         num_res = len(input_sequence)
...
@@ -280,30 +290,31 @@ class DataPipeline:
             query_sequence=input_sequence,
             query_pdb_code=None,
             query_release_date=None,
-            hits=alignments['hhsearch_hits']
+            hits=alignments["hhsearch_hits"],
         )
         sequence_features = make_sequence_features(
             sequence=input_sequence,
             description=input_description,
-            num_res=num_res
+            num_res=num_res,
         )
         msa_features = make_msa_features(
             msas=(
-                alignments['uniref90_msa'],
-                alignments['bfd_msa'],
-                alignments['mgnify_msa']
+                alignments["uniref90_msa"],
+                alignments["bfd_msa"],
+                alignments["mgnify_msa"],
             ),
             deletion_matrices=(
-                alignments['uniref90_deletion_matrix'],
-                alignments['bfd_deletion_matrix'],
-                alignments['mgnify_deletion_matrix']
-            )
+                alignments["uniref90_deletion_matrix"],
+                alignments["bfd_deletion_matrix"],
+                alignments["mgnify_deletion_matrix"],
+            ),
         )

         return {**sequence_features, **msa_features, **templates_result.data}

-    def process_mmcif(self,
+    def process_mmcif(
+        self,
         mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
         alignment_dir: str,
         chain_id: Optional[str] = None,
...
@@ -314,13 +325,11 @@ class DataPipeline:
         If chain_id is None, it is assumed that there is only one chain
         in the object. Otherwise, a ValueError is thrown.
         """
-        if (chain_id is None):
+        if chain_id is None:
             chains = mmcif.structure.get_chains()
             chain = next(chains, None)
-            if (chain is None):
-                raise ValueError('No chains in mmCIF file')
+            if chain is None:
+                raise ValueError("No chains in mmCIF file")
             chain_id = chain.id

         mmcif_feats = make_mmcif_features(mmcif, chain_id)
...
@@ -332,20 +341,20 @@ class DataPipeline:
             query_sequence=input_sequence,
             query_pdb_code=None,
             query_release_date=to_date(mmcif.header["release_date"]),
-            hits=alignments['hhsearch_hits']
+            hits=alignments["hhsearch_hits"],
         )
         msa_features = make_msa_features(
             msas=(
-                alignments['uniref90_msa'],
-                alignments['bfd_msa'],
-                alignments['mgnify_msa']
+                alignments["uniref90_msa"],
+                alignments["bfd_msa"],
+                alignments["mgnify_msa"],
             ),
             deletion_matrices=(
-                alignments['uniref90_deletion_matrix'],
-                alignments['bfd_deletion_matrix'],
-                alignments['mgnify_deletion_matrix']
-            )
+                alignments["uniref90_deletion_matrix"],
+                alignments["bfd_deletion_matrix"],
+                alignments["mgnify_deletion_matrix"],
+            ),
         )

         return {**mmcif_feats, **templates_result.data, **msa_features}
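For reference, `make_sequence_features` as changed above can be exercised on its own. A minimal sketch (the sequence and description are made up; the signature and feature keys are the ones shown in the diff):

from openfold.data.data_pipeline import make_sequence_features

seq = "MKTAYIAKQR"  # hypothetical ten-residue query
feats = make_sequence_features(
    sequence=seq,
    description="toy_query",
    num_res=len(seq),
)
# "aatype" holds one one-hot row per residue; unknown residues map to X.
print(feats["aatype"].shape, feats["seq_length"][0])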
openfold/data/data_transforms.py
(This diff is collapsed.)
openfold/data/feature_pipeline.py

...
@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
 FeatureDict = Mapping[str, np.ndarray]
 TensorDict = Dict[str, torch.Tensor]


 def np_to_tensor_dict(
     np_example: Mapping[str, np.ndarray],
     features: Sequence[str],
 ) -> TensorDict:
     """Creates dict of tensors from a dict of NumPy arrays.

     Args:
...
@@ -54,7 +55,7 @@ def make_data_config(
     cfg = copy.deepcopy(config)
     mode_cfg = cfg[mode]
     with cfg.unlocked():
-        if (mode_cfg.crop_size is None):
+        if mode_cfg.crop_size is None:
             mode_cfg.crop_size = num_res

     feature_names = cfg.common.unsupervised_features
...
@@ -62,7 +63,7 @@ def make_data_config(
     if cfg.common.use_templates:
         feature_names += cfg.common.template_features

-    if (cfg[mode].supervised):
+    if cfg[mode].supervised:
         feature_names += cfg.common.supervised_features

     return cfg, feature_names
...
@@ -75,47 +76,47 @@ def np_example_to_features(
     batch_mode: str,
 ):
     np_example = dict(np_example)
-    num_res = int(np_example['seq_length'][0])
+    num_res = int(np_example["seq_length"][0])
     cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

-    if 'deletion_matrix_int' in np_example:
-        np_example['deletion_matrix'] = (
-            np_example.pop('deletion_matrix_int').astype(np.float32)
-        )
+    if "deletion_matrix_int" in np_example:
+        np_example["deletion_matrix"] = np_example.pop(
+            "deletion_matrix_int"
+        ).astype(np.float32)

-    if batch_mode == 'clamped':
-        np_example['use_clamped_fape'] = (
-            np.array(1.).astype(np.float32)
-        )
-    elif batch_mode == 'unclamped':
-        np_example['use_clamped_fape'] = (
-            np.array(0.).astype(np.float32)
-        )
+    if batch_mode == "clamped":
+        np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
+    elif batch_mode == "unclamped":
+        np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)

     tensor_dict = np_to_tensor_dict(np_example=np_example, features=feature_names)
     with torch.no_grad():
         features = input_pipeline.process_tensors_from_config(
             tensor_dict,
             cfg.common,
             cfg[mode],
             batch_mode=batch_mode,
         )

     return {k: v for k, v in features.items()}


 class FeaturePipeline:
-    def __init__(self,
+    def __init__(
+        self,
         config: ml_collections.ConfigDict,
-        params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
+        params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
+    ):
         self.config = config
         self.params = params

-    def process_features(self,
+    def process_features(
+        self,
         raw_features: FeatureDict,
-        mode: str = 'train',
-        batch_mode: str = 'clamped',
+        mode: str = "train",
+        batch_mode: str = "clamped",
     ) -> FeatureDict:
         return np_example_to_features(
             np_example=raw_features,
...
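A minimal sketch of driving the pipeline above. Only the `FeaturePipeline` and `process_features` signatures are taken from the diff; the config object and input features are assumptions:

from openfold.data.feature_pipeline import FeaturePipeline

# data_cfg and raw_features are assumed to exist: data_cfg is an
# ml_collections.ConfigDict with the common/train sections used above, and
# raw_features is a FeatureDict such as DataPipeline.process_fasta returns.
pipeline = FeaturePipeline(config=data_cfg)
processed = pipeline.process_features(
    raw_features,
    mode="train",
    batch_mode="clamped",  # per the diff above, sets use_clamped_fape to 1.0
)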
openfold/data/input_pipeline.py

...
@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
         data_transforms.make_hhblits_profile,
     ]
     if common_cfg.use_templates:
-        transforms.extend([
-            data_transforms.fix_templates_aatype,
-            data_transforms.make_template_mask,
-            data_transforms.make_pseudo_beta('template_')
-        ])
-        if (common_cfg.use_template_torsion_angles):
-            transforms.extend([
-                data_transforms.atom37_to_torsion_angles('template_'),
-            ])
+        transforms.extend(
+            [
+                data_transforms.fix_templates_aatype,
+                data_transforms.make_template_mask,
+                data_transforms.make_pseudo_beta("template_"),
+            ]
+        )
+        if common_cfg.use_template_torsion_angles:
+            transforms.extend(
+                [
+                    data_transforms.atom37_to_torsion_angles("template_"),
+                ]
+            )

-    transforms.extend([
-        data_transforms.make_atom14_masks,
-    ])
+    transforms.extend(
+        [
+            data_transforms.make_atom14_masks,
+        ]
+    )

-    if (mode_cfg.supervised):
-        transforms.extend([
-            data_transforms.make_atom14_positions,
-            data_transforms.atom37_to_frames,
-            data_transforms.atom37_to_torsion_angles(''),
-            data_transforms.make_pseudo_beta(''),
-            data_transforms.get_backbone_frames,
-            data_transforms.get_chi_angles,
-        ])
+    if mode_cfg.supervised:
+        transforms.extend(
+            [
+                data_transforms.make_atom14_positions,
+                data_transforms.atom37_to_frames,
+                data_transforms.atom37_to_torsion_angles(""),
+                data_transforms.make_pseudo_beta(""),
+                data_transforms.get_backbone_frames,
+                data_transforms.get_chi_angles,
+            ]
+        )

     return transforms
...
@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
         data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
     )

-    if 'masked_msa' in common_cfg:
+    if "masked_msa" in common_cfg:
         # Masked MSA should come *before* MSA clustering so that
         # the clustering and full MSA profile do not leak information about
         # the masked locations and secret corrupted locations.
         transforms.append(
             data_transforms.make_masked_msa(
                 common_cfg.masked_msa,
                 mode_cfg.masked_msa_replace_fraction
             )
         )
...
@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
     if mode_cfg.fixed_size:
         transforms.append(data_transforms.select_feat(list(crop_feats)))
-        transforms.append(data_transforms.random_crop_to_size(
-            mode_cfg.crop_size,
-            mode_cfg.max_templates,
-            crop_feats,
-            mode_cfg.subsample_templates,
-            batch_mode=batch_mode,
-            seed=torch.Generator().seed()
-        ))
-        transforms.append(data_transforms.make_fixed_size(
-            crop_feats,
-            pad_msa_clusters,
-            common_cfg.max_extra_msa,
-            mode_cfg.crop_size,
-            mode_cfg.max_templates))
+        transforms.append(
+            data_transforms.random_crop_to_size(
+                mode_cfg.crop_size,
+                mode_cfg.max_templates,
+                crop_feats,
+                mode_cfg.subsample_templates,
+                batch_mode=batch_mode,
+                seed=torch.Generator().seed(),
+            )
+        )
+        transforms.append(
+            data_transforms.make_fixed_size(
+                crop_feats,
+                pad_msa_clusters,
+                common_cfg.max_extra_msa,
+                mode_cfg.crop_size,
+                mode_cfg.max_templates,
+            )
+        )
     else:
         transforms.append(
             data_transforms.crop_templates(mode_cfg.max_templates)
...
@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
 def process_tensors_from_config(
-    tensors, common_cfg, mode_cfg, batch_mode='clamped'
+    tensors, common_cfg, mode_cfg, batch_mode="clamped"
 ):
     """Based on the config, apply filters and transformations to the data."""
...
@@ -136,12 +147,10 @@ def process_tensors_from_config(
         d = data.copy()
         fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
         fn = compose(fns)
-        d['ensemble_index'] = i
+        d["ensemble_index"] = i
         return fn(d)

-    tensors = compose(
-        nonensembled_transform_fns(common_cfg, mode_cfg)
-    )(tensors)
+    tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(
+        tensors
+    )

     tensors_0 = wrap_ensemble_fn(tensors, 0)
     num_ensemble = mode_cfg.num_ensemble
...
@@ -150,8 +159,9 @@ def process_tensors_from_config(
         num_ensemble *= common_cfg.num_recycle + 1

     if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
-        tensors = map_fn(
-            lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble))
+        tensors = map_fn(
+            lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
+        )
     else:
         tensors = tree.map_structure(lambda x: x[None], tensors_0)
...
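The transform lists built above are applied through a `compose` helper. That helper is not part of this diff; a minimal sketch of the composition it performs here, under the assumption that each transform maps a tensor dict to a tensor dict:

def compose(fns):
    # Apply each transform in order, feeding each output to the next.
    def composed(x):
        for fn in fns:
            x = fn(x)
        return x
    return composed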
openfold/data/mmcif_parsing.py

...
@@ -90,6 +90,7 @@ class MmcifObject:
             ...}}
     raw_string: The raw string used to construct the MmcifObject.
     """
+
     file_id: str
     header: PdbHeader
     structure: PdbStructure
...
@@ -107,6 +108,7 @@ class ParsingResult:
         parsed.
     errors: A dict mapping (file_id, chain_id) to any exception generated.
     """
+
     mmcif_object: Optional[MmcifObject]
     errors: Mapping[Tuple[str, str], Any]
...
@@ -115,8 +117,9 @@ class ParseError(Exception):
     """An error indicating that an mmCIF file could not be parsed."""


-def mmcif_loop_to_list(prefix: str,
-                       parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
+def mmcif_loop_to_list(
+    prefix: str, parsed_info: MmCIFDict
+) -> Sequence[Mapping[str, str]]:
     """Extracts loop associated with a prefix from mmCIF data as a list.

     Reference for loop_ in mmCIF:
...
@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
             data.append(value)

     assert all([len(xs) == len(data[0]) for xs in data]), (
-        'mmCIF error: Not all loops are the same length: %s' % cols
+        "mmCIF error: Not all loops are the same length: %s" % cols
     )

     return [dict(zip(cols, xs)) for xs in zip(*data)]


-def mmcif_loop_to_dict(prefix: str,
+def mmcif_loop_to_dict(
+    prefix: str,
     index: str,
     parsed_info: MmCIFDict,
 ) -> Mapping[str, Mapping[str, str]]:
     """Extracts loop associated with a prefix from mmCIF data as a dictionary.

     Args:
...
@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
     return {entry[index]: entry for entry in entries}


-def parse(*,
-          file_id: str,
-          mmcif_string: str,
-          catch_all_errors: bool = True) -> ParsingResult:
+def parse(
+    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
+) -> ParsingResult:
     """Entry point, parses an mmcif_string.

     Args:
...
@@ -188,7 +192,7 @@ def parse(*,
     try:
         parser = PDB.MMCIFParser(QUIET=True)
         handle = io.StringIO(mmcif_string)
-        full_structure = parser.get_structure('', handle)
+        full_structure = parser.get_structure("", handle)
         first_model_structure = _get_first_model(full_structure)
         # Extract the _mmcif_dict from the parser, which contains useful fields not
         # reflected in the Biopython structure.
...
@@ -206,9 +210,12 @@ def parse(*,
         valid_chains = _get_protein_chains(parsed_info=parsed_info)
         if not valid_chains:
             return ParsingResult(
-                None, {(file_id, ''): 'No protein chains found in this file.'})
-        seq_start_num = {chain_id: min([monomer.num for monomer in seq])
-                         for chain_id, seq in valid_chains.items()}
+                None, {(file_id, ""): "No protein chains found in this file."}
+            )
+        seq_start_num = {
+            chain_id: min([monomer.num for monomer in seq])
+            for chain_id, seq in valid_chains.items()
+        }

         # Loop over the atoms for which we have coordinates. Populate two mappings:
         # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
...
@@ -217,34 +224,42 @@ def parse(*,
         mmcif_to_author_chain_id = {}
         seq_to_structure_mappings = {}
         for atom in _get_atom_site_list(parsed_info):
-            if atom.model_num != '1':
+            if atom.model_num != "1":
                 # We only process the first model at the moment.
                 continue

             mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

             if atom.mmcif_chain_id in valid_chains:
-                hetflag = ' '
-                if atom.hetatm_atom == 'HETATM':
+                hetflag = " "
+                if atom.hetatm_atom == "HETATM":
                     # Water atoms are assigned a special hetflag of W in Biopython. We
                     # need to do the same, so that this hetflag can be used to fetch
                     # a residue from the Biopython structure by id.
-                    if atom.residue_name in ('HOH', 'WAT'):
-                        hetflag = 'W'
+                    if atom.residue_name in ("HOH", "WAT"):
+                        hetflag = "W"
                     else:
-                        hetflag = 'H_' + atom.residue_name
+                        hetflag = "H_" + atom.residue_name

                 insertion_code = atom.insertion_code
                 if not _is_set(atom.insertion_code):
-                    insertion_code = ' '
-                position = ResiduePosition(chain_id=atom.author_chain_id,
+                    insertion_code = " "
+                position = ResiduePosition(
+                    chain_id=atom.author_chain_id,
                     residue_number=int(atom.author_seq_num),
-                    insertion_code=insertion_code)
-                seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
-                current = seq_to_structure_mappings.get(atom.author_chain_id, {})
-                current[seq_idx] = ResidueAtPosition(position=position,
+                    insertion_code=insertion_code,
+                )
+                seq_idx = (
+                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
+                )
+                current = seq_to_structure_mappings.get(
+                    atom.author_chain_id, {}
+                )
+                current[seq_idx] = ResidueAtPosition(
+                    position=position,
                     name=atom.residue_name,
                     is_missing=False,
-                    hetflag=hetflag)
+                    hetflag=hetflag,
+                )
                 seq_to_structure_mappings[atom.author_chain_id] = current

         # Add missing residue information to seq_to_structure_mappings.
...
@@ -253,19 +268,21 @@ def parse(*,
             current_mapping = seq_to_structure_mappings[author_chain]
             for idx, monomer in enumerate(seq_info):
                 if idx not in current_mapping:
-                    current_mapping[idx] = ResidueAtPosition(position=None,
+                    current_mapping[idx] = ResidueAtPosition(
+                        position=None,
                         name=monomer.id,
                         is_missing=True,
-                        hetflag=' ')
+                        hetflag=" ",
+                    )

         author_chain_to_sequence = {}
         for chain_id, seq_info in valid_chains.items():
             author_chain = mmcif_to_author_chain_id[chain_id]
             seq = []
             for monomer in seq_info:
-                code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
-                seq.append(code if len(code) == 1 else 'X')
-            seq = ''.join(seq)
+                code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
+                seq.append(code if len(code) == 1 else "X")
+            seq = "".join(seq)
             author_chain_to_sequence[author_chain] = seq

         mmcif_object = MmcifObject(
...
@@ -274,11 +291,12 @@ def parse(*,
             structure=first_model_structure,
             chain_to_seqres=author_chain_to_sequence,
             seqres_to_structure=seq_to_structure_mappings,
-            raw_string=parsed_info)
+            raw_string=parsed_info,
+        )

         return ParsingResult(mmcif_object=mmcif_object, errors=errors)
     except Exception as e:  # pylint:disable=broad-except
-        errors[(file_id, '')] = e
+        errors[(file_id, "")] = e
         if not catch_all_errors:
             raise
         return ParsingResult(mmcif_object=None, errors=errors)
...
@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
     """Returns the first model in a Biopython structure."""
     return next(structure.get_models())

 _MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21


 def get_release_date(parsed_info: MmCIFDict) -> str:
     """Returns the oldest revision date."""
-    revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
+    revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
     return min(revision_dates)
...
@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
     """Returns a basic header containing method, release date and resolution."""
     header = {}

-    experiments = mmcif_loop_to_list('_exptl.', parsed_info)
-    header['structure_method'] = ','.join([
-        experiment['_exptl.method'].lower() for experiment in experiments])
+    experiments = mmcif_loop_to_list("_exptl.", parsed_info)
+    header["structure_method"] = ",".join(
+        [experiment["_exptl.method"].lower() for experiment in experiments]
+    )

     # Note: The release_date here corresponds to the oldest revision. We prefer to
     # use this for dataset filtering over the deposition_date.
-    if '_pdbx_audit_revision_history.revision_date' in parsed_info:
-        header['release_date'] = get_release_date(parsed_info)
+    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
+        header["release_date"] = get_release_date(parsed_info)
     else:
-        logging.warning('Could not determine release_date: %s', parsed_info['_entry.id'])
+        logging.warning(
+            "Could not determine release_date: %s", parsed_info["_entry.id"]
+        )

-    header['resolution'] = 0.00
-    for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
-                    '_reflns.d_resolution_high'):
+    header["resolution"] = 0.00
+    for res_key in (
+        "_refine.ls_d_res_high",
+        "_em_3d_reconstruction.resolution",
+        "_reflns.d_resolution_high",
+    ):
         if res_key in parsed_info:
             try:
                 raw_resolution = parsed_info[res_key][0]
-                header['resolution'] = float(raw_resolution)
+                header["resolution"] = float(raw_resolution)
             except ValueError:
-                logging.warning('Invalid resolution format: %s', parsed_info[res_key])
+                logging.warning(
+                    "Invalid resolution format: %s", parsed_info[res_key]
+                )

     return header


 def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
     """Returns list of atom sites; contains data not present in the structure."""
-    return [AtomSite(*site) for site in zip(  # pylint:disable=g-complex-comprehension
-        parsed_info['_atom_site.label_comp_id'],
-        parsed_info['_atom_site.auth_asym_id'],
-        parsed_info['_atom_site.label_asym_id'],
-        parsed_info['_atom_site.auth_seq_id'],
-        parsed_info['_atom_site.label_seq_id'],
-        parsed_info['_atom_site.pdbx_PDB_ins_code'],
-        parsed_info['_atom_site.group_PDB'],
-        parsed_info['_atom_site.pdbx_PDB_model_num'],
-    )]
+    return [
+        AtomSite(*site)
+        for site in zip(  # pylint:disable=g-complex-comprehension
+            parsed_info["_atom_site.label_comp_id"],
+            parsed_info["_atom_site.auth_asym_id"],
+            parsed_info["_atom_site.label_asym_id"],
+            parsed_info["_atom_site.auth_seq_id"],
+            parsed_info["_atom_site.label_seq_id"],
+            parsed_info["_atom_site.pdbx_PDB_ins_code"],
+            parsed_info["_atom_site.group_PDB"],
+            parsed_info["_atom_site.pdbx_PDB_model_num"],
+        )
+    ]


 def _get_protein_chains(
-    *, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
+    *, parsed_info: Mapping[str, Any]
+) -> Mapping[ChainId, Sequence[Monomer]]:
     """Extracts polymer information for protein chains only.

     Args:
...
@@ -351,26 +381,29 @@ def _get_protein_chains(
         A dict mapping mmcif chain id to a list of Monomers.
     """
     # Get polymer information for each entity in the structure.
-    entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
+    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

     polymers = collections.defaultdict(list)
     for entity_poly_seq in entity_poly_seqs:
-        polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
-            Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
-                    num=int(entity_poly_seq['_entity_poly_seq.num'])))
+        polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
+            Monomer(
+                id=entity_poly_seq["_entity_poly_seq.mon_id"],
+                num=int(entity_poly_seq["_entity_poly_seq.num"]),
+            )
+        )

     # Get chemical compositions. Will allow us to identify which of these polymers
     # are proteins.
-    chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
+    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

     # Get chains information for each entity. Necessary so that we can return a
     # dict keyed on chain id rather than entity.
-    struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
+    struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)

     entity_to_mmcif_chains = collections.defaultdict(list)
     for struct_asym in struct_asyms:
-        chain_id = struct_asym['_struct_asym.id']
-        entity_id = struct_asym['_struct_asym.entity_id']
+        chain_id = struct_asym["_struct_asym.id"]
+        entity_id = struct_asym["_struct_asym.entity_id"]
         entity_to_mmcif_chains[entity_id].append(chain_id)

     # Identify and return the valid protein chains.
...
@@ -379,8 +412,12 @@ def _get_protein_chains(
         chain_ids = entity_to_mmcif_chains[entity_id]

         # Reject polymers without any peptide-like components, such as DNA/RNA.
-        if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
-                for monomer in seq_info]):
+        if any(
+            [
+                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
+                for monomer in seq_info
+            ]
+        ):
             for chain_id in chain_ids:
                 valid_chains[chain_id] = seq_info

     return valid_chains
...
@@ -388,19 +425,18 @@ def _get_protein_chains(
 def _is_set(data: str) -> bool:
     """Returns False if data is a special mmCIF character indicating 'unset'."""
-    return data not in ('.', '?')
+    return data not in (".", "?")


 def get_atom_coords(
     mmcif_object: MmcifObject, chain_id: str
 ) -> Tuple[np.ndarray, np.ndarray]:
     # Locate the right chain
     chains = list(mmcif_object.structure.get_chains())
     relevant_chains = [c for c in chains if c.id == chain_id]
     if len(relevant_chains) != 1:
         raise MultipleChainsError(
-            f'Expected exactly one chain in structure with id {chain_id}.'
+            f"Expected exactly one chain in structure with id {chain_id}."
         )
     chain = relevant_chains[0]
...
@@ -417,19 +453,23 @@ def get_atom_coords(
         mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
         res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
         if not res_at_position.is_missing:
-            res = chain[(res_at_position.hetflag,
-                         res_at_position.position.residue_number,
-                         res_at_position.position.insertion_code)]
+            res = chain[
+                (
+                    res_at_position.hetflag,
+                    res_at_position.position.residue_number,
+                    res_at_position.position.insertion_code,
+                )
+            ]
             for atom in res.get_atoms():
                 atom_name = atom.get_name()
                 x, y, z = atom.get_coord()
                 if atom_name in residue_constants.atom_order.keys():
                     pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                     mask[residue_constants.atom_order[atom_name]] = 1.0
-                elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
+                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                     # Put the coords of the selenium atom in the sulphur column
-                    pos[residue_constants.atom_order['SD']] = [x, y, z]
-                    mask[residue_constants.atom_order['SD']] = 1.0
+                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
+                    mask[residue_constants.atom_order["SD"]] = 1.0

         all_atom_positions[res_index] = pos
         all_atom_mask[res_index] = mask
...
@@ -440,22 +480,22 @@ def get_atom_coords(
 def generate_mmcif_cache(mmcif_dir: str, out_path: str):
     data = {}
     for f in os.listdir(mmcif_dir):
-        if (f.endswith('.cif')):
-            with open(os.path.join(mmcif_dir, f), 'r') as fp:
+        if f.endswith(".cif"):
+            with open(os.path.join(mmcif_dir, f), "r") as fp:
                 mmcif_string = fp.read()
             file_id = os.path.splitext(f)[0]
             mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
-            if (mmcif.mmcif_object is None):
-                logging.warning(f'Could not parse {f}. Skipping...')
+            if mmcif.mmcif_object is None:
+                logging.warning(f"Could not parse {f}. Skipping...")
                 continue
             else:
                 mmcif = mmcif.mmcif_object
             local_data = {}
-            local_data['release_date'] = mmcif.header["release_date"]
-            local_data['no_chains'] = len(list(mmcif.structure.get_chains()))
+            local_data["release_date"] = mmcif.header["release_date"]
+            local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
             data[file_id] = local_data

-    with open(out_path, 'w') as fp:
+    with open(out_path, "w") as fp:
         fp.write(json.dumps(data))
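A minimal sketch of calling the parser changed above. The file name is hypothetical; `parse`, `ParsingResult.mmcif_object`, `ParsingResult.errors`, and `chain_to_seqres` are all shown in this diff:

from openfold.data import mmcif_parsing

with open("1abc.cif") as f:  # hypothetical mmCIF file
    result = mmcif_parsing.parse(file_id="1abc", mmcif_string=f.read())

if result.mmcif_object is None:
    print(result.errors)  # maps (file_id, chain_id) to the exception raised
else:
    print(result.mmcif_object.chain_to_seqres)  # chain id -> one-letter sequence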
openfold/data/parsers.py

...
@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 DeletionMatrix = Sequence[Sequence[int]]


 @dataclasses.dataclass(frozen=True)
 class TemplateHit:
     """Class representing a template hit."""
+
     index: int
     name: str
     aligned_cols: int
...
@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     index = -1
     for line in fasta_string.splitlines():
         line = line.strip()
-        if line.startswith('>'):
+        if line.startswith(">"):
             index += 1
             descriptions.append(line[1:])  # Remove the '>' at the beginning.
-            sequences.append('')
+            sequences.append("")
             continue
         elif not line:
             continue  # Skip blank lines.
...
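The loop above accumulates one description and one sequence per `>` header. A small worked example against the parser as shown:

from openfold.data.parsers import parse_fasta

fasta_str = ">query toy example\nMKTAYIAK\nQR\n"
seqs, descs = parse_fasta(fasta_str)
assert seqs == ["MKTAYIAKQR"]          # continuation lines are concatenated
assert descs == ["query toy example"]  # the leading '>' is stripped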
@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
     return sequences, descriptions


-def parse_stockholm(stockholm_string: str
-                    ) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
+def parse_stockholm(
+    stockholm_string: str,
+) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
     """Parses sequences and deletion matrix from stockholm format alignment.

     Args:
...
@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
     name_to_sequence = collections.OrderedDict()
     for line in stockholm_string.splitlines():
         line = line.strip()
-        if not line or line.startswith(('#', '//')):
+        if not line or line.startswith(("#", "//")):
             continue
         name, sequence = line.split()
         if name not in name_to_sequence:
-            name_to_sequence[name] = ''
+            name_to_sequence[name] = ""
         name_to_sequence[name] += sequence

     msa = []
     deletion_matrix = []

-    query = ''
+    query = ""
     keep_columns = []
     for seq_index, sequence in enumerate(name_to_sequence.values()):
         if seq_index == 0:
             # Gather the columns with gaps from the query
             query = sequence
-            keep_columns = [i for i, res in enumerate(query) if res != '-']
+            keep_columns = [i for i, res in enumerate(query) if res != "-"]

         # Remove the columns with gaps in the query from all sequences.
-        aligned_sequence = ''.join([sequence[c] for c in keep_columns])
+        aligned_sequence = "".join([sequence[c] for c in keep_columns])

         msa.append(aligned_sequence)
...
@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
         deletion_vec = []
         deletion_count = 0
         for seq_res, query_res in zip(sequence, query):
-            if seq_res != '-' or query_res != '-':
-                if query_res == '-':
+            if seq_res != "-" or query_res != "-":
+                if query_res == "-":
                     deletion_count += 1
                 else:
                     deletion_vec.append(deletion_count)
...
deletion_matrix
.
append
(
deletion_vec
)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table
=
str
.
maketrans
(
''
,
''
,
string
.
ascii_lowercase
)
deletion_table
=
str
.
maketrans
(
""
,
""
,
string
.
ascii_lowercase
)
aligned_sequences
=
[
s
.
translate
(
deletion_table
)
for
s
in
sequences
]
return
aligned_sequences
,
deletion_matrix
def
_convert_sto_seq_to_a3m
(
query_non_gaps
:
Sequence
[
bool
],
sto_seq
:
str
)
->
Iterable
[
str
]:
query_non_gaps
:
Sequence
[
bool
],
sto_seq
:
str
)
->
Iterable
[
str
]:
for
is_query_res_non_gap
,
sequence_res
in
zip
(
query_non_gaps
,
sto_seq
):
if
is_query_res_non_gap
:
yield
sequence_res
elif
sequence_res
!=
'-'
:
elif
sequence_res
!=
"-"
:
yield
sequence_res
.
lower
()
def
convert_stockholm_to_a3m
(
stockholm_format
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
str
:
def
convert_stockholm_to_a3m
(
stockholm_format
:
str
,
max_sequences
:
Optional
[
int
]
=
None
)
->
str
:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions
=
{}
sequences
=
{}
reached_max_sequences
=
False
for
line
in
stockholm_format
.
splitlines
():
reached_max_sequences
=
max_sequences
and
len
(
sequences
)
>=
max_sequences
if
line
.
strip
()
and
not
line
.
startswith
((
'#'
,
'//'
)):
reached_max_sequences
=
(
max_sequences
and
len
(
sequences
)
>=
max_sequences
)
if
line
.
strip
()
and
not
line
.
startswith
((
"#"
,
"//"
)):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname
,
aligned_seq
=
line
.
split
(
maxsplit
=
1
)
if
seqname
not
in
sequences
:
if
reached_max_sequences
:
continue
sequences
[
seqname
]
=
''
sequences
[
seqname
]
=
""
sequences
[
seqname
]
+=
aligned_seq
for
line
in
stockholm_format
.
splitlines
():
if
line
[:
4
]
==
'
#=GS
'
:
if
line
[:
4
]
==
"
#=GS
"
:
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns
=
line
.
split
(
maxsplit
=
3
)
seqname
,
feature
=
columns
[
1
:
3
]
value
=
columns
[
3
]
if
len
(
columns
)
==
4
else
''
if
feature
!=
'
DE
'
:
value
=
columns
[
3
]
if
len
(
columns
)
==
4
else
""
if
feature
!=
"
DE
"
:
continue
if
reached_max_sequences
and
seqname
not
in
sequences
:
continue
...
...
@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
     a3m_sequences = {}
     # query_sequence is assumed to be the first sequence
     query_sequence = next(iter(sequences.values()))
-    query_non_gaps = [res != '-' for res in query_sequence]
+    query_non_gaps = [res != "-" for res in query_sequence]
     for seqname, sto_sequence in sequences.items():
-        a3m_sequences[seqname] = ''.join(
-            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
+        a3m_sequences[seqname] = "".join(
+            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
+        )

-    fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
-                    for k in a3m_sequences)
-    return '\n'.join(fasta_chunks) + '\n'  # Include terminating newline.
+    fasta_chunks = (
+        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
+        for k in a3m_sequences
+    )
+    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.


 def _get_hhr_line_regex_groups(
-    regex_pattern: str, line: str) -> Sequence[Optional[str]]:
+    regex_pattern: str, line: str
+) -> Sequence[Optional[str]]:
     match = re.match(regex_pattern, line)
     if match is None:
-        raise RuntimeError(f'Could not parse query line {line}')
+        raise RuntimeError(f"Could not parse query line {line}")
     return match.groups()


 def _update_hhr_residue_indices_list(
-    sequence: str, start_index: int, indices_list: List[int]):
+    sequence: str, start_index: int, indices_list: List[int]
+):
     """Computes the relative indices for each residue with respect to the original sequence."""
     counter = start_index
     for symbol in sequence:
-        if symbol == '-':
+        if symbol == "-":
             indices_list.append(-1)
         else:
             indices_list.append(counter)
...
@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
     # Parse the summary line.
     pattern = (
-        'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t '
-        ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
-        ']*Template_Neff=(.*)'
+        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t "
+        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
+        "]*Template_Neff=(.*)"
     )
     match = re.match(pattern, detailed_lines[2])
     if match is None:
         raise RuntimeError(
-            'Could not parse section: %s. Expected this:\n%s to contain summary.'
-            % (detailed_lines, detailed_lines[2]))
-    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
-        float(x) for x in match.groups()]
+            "Could not parse section: %s. Expected this:\n%s to contain summary."
+            % (detailed_lines, detailed_lines[2])
+        )
+    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
+        float(x) for x in match.groups()
+    ]

     # The next section reads the detailed comparisons. These are in a 'human
     # readable' format which has a fixed length. The strategy employed is to
     # assume that each block starts with the query sequence line, and to parse
     # that with a regexp in order to deduce the fixed length used for that block.
-    query = ''
-    hit_sequence = ''
+    query = ""
+    hit_sequence = ""
     indices_query = []
     indices_hit = []
     length_block = None

     for line in detailed_lines[3:]:
         # Parse the query sequence line
-        if (line.startswith('Q ') and not line.startswith('Q ss_dssp')
-                and not line.startswith('Q ss_pred')
-                and not line.startswith('Q Consensus')):
+        if (
+            line.startswith("Q ")
+            and not line.startswith("Q ss_dssp")
+            and not line.startswith("Q ss_pred")
+            and not line.startswith("Q Consensus")
+        ):
             # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
             # everything after that.
             # start sequence end total_sequence_length
-            patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
+            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
             groups = _get_hhr_line_regex_groups(patt, line[17:])

             # Get the length of the parsed block using the start and finish indices,
...
@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             start = int(groups[0]) - 1  # Make index zero based.
             delta_query = groups[1]
             end = int(groups[2])
-            num_insertions = len([x for x in delta_query if x == '-'])
+            num_insertions = len([x for x in delta_query if x == "-"])
             length_block = end - start + num_insertions
             assert length_block == len(delta_query)
...
@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
             query += delta_query
             _update_hhr_residue_indices_list(delta_query, start, indices_query)

-        elif line.startswith('T '):
+        elif line.startswith("T "):
             # Parse the hit sequence.
-            if (not line.startswith('T ss_dssp')
-                    and not line.startswith('T ss_pred')
-                    and not line.startswith('T Consensus')):
+            if (
+                not line.startswith("T ss_dssp")
+                and not line.startswith("T ss_pred")
+                and not line.startswith("T Consensus")
+            ):
                 # Thus the first 17 characters must be 'T <hit_name> ', and we can
                 # parse everything after that.
                 # start sequence end total_sequence_length
-                patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
+                patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
                 groups = _get_hhr_line_regex_groups(patt, line[17:])
                 start = int(groups[0]) - 1  # Make index zero based.
                 delta_hit_sequence = groups[1]
...
@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
                 # Update the hit sequence and indices list.
                 hit_sequence += delta_hit_sequence
-                _update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)
+                _update_hhr_residue_indices_list(
+                    delta_hit_sequence, start, indices_hit
+                )

     return TemplateHit(
         index=number_of_hit,
...
@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
     # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
     # iterate through each paragraph to parse each hit.
-    block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
+    block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]

     hits = []
     if block_starts:
         block_starts.append(len(lines))  # Add the end of the final block.
         for i in range(len(block_starts) - 1):
-            hits.append(_parse_hhr_hit(
-                lines[block_starts[i]:block_starts[i + 1]]))
+            hits.append(
+                _parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
+            )
     return hits


 def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
     """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
-    e_values = {'query': 0}
-    lines = [line for line in tblout.splitlines() if line[0] != '#']
+    e_values = {"query": 0}
+    lines = [line for line in tblout.splitlines() if line[0] != "#"]
     # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
     # space-delimited. Relevant fields are (1) target name: and
     # (5) E-value (full sequence) (numbering from 1).
...
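A minimal sketch of consuming the parser above on an HHsearch output file. The path is hypothetical; `parse_hhr` and the `TemplateHit` fields used are shown in this diff:

from openfold.data.parsers import parse_hhr

with open("pdb70_hits.hhr") as f:  # hypothetical .hhr file written by HHSearch
    hits = parse_hhr(f.read())
for hit in hits:
    print(hit.index, hit.name, hit.aligned_cols)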
openfold/data/templates.py
(This diff is collapsed.)
openfold/data/tools/hhblits.py

...
@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
 class HHBlits:
     """Python wrapper of the HHblits binary."""

-    def __init__(self,
+    def __init__(
+        self,
         *,
         binary_path: str,
         databases: Sequence[str],
...
@@ -44,7 +45,8 @@ class HHBlits:
         all_seqs: bool = False,
         alt: Optional[int] = None,
         p: int = _HHBLITS_DEFAULT_P,
-        z: int = _HHBLITS_DEFAULT_Z):
+        z: int = _HHBLITS_DEFAULT_Z,
+    ):
         """Initializes the Python HHblits wrapper.

         Args:
...
@@ -77,9 +79,13 @@ class HHBlits:
         self.databases = databases

         for database_path in self.databases:
-            if not glob.glob(database_path + '_*'):
-                logging.error('Could not find HHBlits database %s', database_path)
-                raise ValueError(f'Could not find HHBlits database {database_path}')
+            if not glob.glob(database_path + "_*"):
+                logging.error(
+                    "Could not find HHBlits database %s", database_path
+                )
+                raise ValueError(
+                    f"Could not find HHBlits database {database_path}"
+                )

         self.n_cpu = n_cpu
         self.n_iter = n_iter
...
@@ -95,52 +101,66 @@ class HHBlits:
     def query(self, input_fasta_path: str) -> Mapping[str, Any]:
         """Queries the database using HHblits."""
-        with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
-            a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

             db_cmd = []
             for db_path in self.databases:
-                db_cmd.append('-d')
+                db_cmd.append("-d")
                 db_cmd.append(db_path)
             cmd = [
                 self.binary_path,
-                '-i', input_fasta_path,
-                '-cpu', str(self.n_cpu),
-                '-oa3m', a3m_path,
-                '-o', '/dev/null',
-                '-n', str(self.n_iter),
-                '-e', str(self.e_value),
-                '-maxseq', str(self.maxseq),
-                '-realign_max', str(self.realign_max),
-                '-maxfilt', str(self.maxfilt),
-                '-min_prefilter_hits', str(self.min_prefilter_hits)]
+                "-i", input_fasta_path,
+                "-cpu", str(self.n_cpu),
+                "-oa3m", a3m_path,
+                "-o", "/dev/null",
+                "-n", str(self.n_iter),
+                "-e", str(self.e_value),
+                "-maxseq", str(self.maxseq),
+                "-realign_max", str(self.realign_max),
+                "-maxfilt", str(self.maxfilt),
+                "-min_prefilter_hits", str(self.min_prefilter_hits),
+            ]
             if self.all_seqs:
-                cmd += ['-all']
+                cmd += ["-all"]
             if self.alt:
-                cmd += ['-alt', str(self.alt)]
+                cmd += ["-alt", str(self.alt)]
             if self.p != _HHBLITS_DEFAULT_P:
-                cmd += ['-p', str(self.p)]
+                cmd += ["-p", str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
-                cmd += ['-Z', str(self.z)]
+                cmd += ["-Z", str(self.z)]
             cmd += db_cmd

-            logging.info('Launching subprocess "%s"', ' '.join(cmd))
+            logging.info('Launching subprocess "%s"', " ".join(cmd))
             process = subprocess.Popen(
-                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )

-            with utils.timing('HHblits query'):
+            with utils.timing("HHblits query"):
                 stdout, stderr = process.communicate()
                 retcode = process.wait()

             if retcode:
                 # Logs have a 15k character limit, so log HHblits error line by line.
-                logging.error('HHblits failed. HHblits stderr begin:')
-                for error_line in stderr.decode('utf-8').splitlines():
+                logging.error("HHblits failed. HHblits stderr begin:")
+                for error_line in stderr.decode("utf-8").splitlines():
                     if error_line.strip():
                         logging.error(error_line.strip())
-                logging.error('HHblits stderr end')
-                raise RuntimeError(
-                    'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n'
-                    % (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
+                logging.error("HHblits stderr end")
+                raise RuntimeError(
+                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
+                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
+                )

             with open(a3m_path) as f:
                 a3m = f.read()
...
@@ -150,5 +170,6 @@ class HHBlits:
             output=stdout,
             stderr=stderr,
             n_iter=self.n_iter,
-            e_value=self.e_value)
+            e_value=self.e_value,
+        )
         return raw_output
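A minimal usage sketch for the wrapper above. The binary and database paths are hypothetical; the constructor keywords and the "a3m" key of the result dict are the ones shown in these diffs:

from openfold.data.tools import hhblits

runner = hhblits.HHBlits(
    binary_path="/usr/bin/hhblits",  # hypothetical binary location
    databases=["/data/uniclust30/uniclust30_2018_08"],  # hypothetical db prefix
)
result = runner.query("query.fasta")  # hypothetical FASTA path
a3m = result["a3m"]  # the alignment read back from output.a3m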
openfold/data/tools/hhsearch.py

...
@@ -26,12 +26,14 @@ from openfold.data.np import utils
 class HHSearch:
     """Python wrapper of the HHsearch binary."""

-    def __init__(self,
+    def __init__(
+        self,
         *,
         binary_path: str,
         databases: Sequence[str],
         n_cpu: int = 2,
-        maxseq: int = 1_000_000):
+        maxseq: int = 1_000_000,
+    ):
         """Initializes the Python HHsearch wrapper.

         Args:
...
@@ -52,41 +54,52 @@ class HHSearch:
         self.maxseq = maxseq

         for database_path in self.databases:
-            if not glob.glob(database_path + '_*'):
-                logging.error('Could not find HHsearch database %s', database_path)
-                raise ValueError(f'Could not find HHsearch database {database_path}')
+            if not glob.glob(database_path + "_*"):
+                logging.error(
+                    "Could not find HHsearch database %s", database_path
+                )
+                raise ValueError(
+                    f"Could not find HHsearch database {database_path}"
+                )

     def query(self, a3m: str) -> str:
         """Queries the database using HHsearch using a given a3m."""
-        with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
-            input_path = os.path.join(query_tmp_dir, 'query.a3m')
-            hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
-            with open(input_path, 'w') as f:
+        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+            input_path = os.path.join(query_tmp_dir, "query.a3m")
+            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
+            with open(input_path, "w") as f:
                 f.write(a3m)

             db_cmd = []
             for db_path in self.databases:
-                db_cmd.append('-d')
+                db_cmd.append("-d")
                 db_cmd.append(db_path)
-            cmd = [self.binary_path,
-                   '-i', input_path,
-                   '-o', hhr_path,
-                   '-maxseq', str(self.maxseq),
-                   '-cpu', str(self.n_cpu),
+            cmd = [
+                self.binary_path,
+                "-i", input_path,
+                "-o", hhr_path,
+                "-maxseq", str(self.maxseq),
+                "-cpu", str(self.n_cpu),
             ] + db_cmd

-            logging.info('Launching subprocess "%s"', ' '.join(cmd))
+            logging.info('Launching subprocess "%s"', " ".join(cmd))
             process = subprocess.Popen(
-                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-            with utils.timing('HHsearch query'):
+                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
+            with utils.timing("HHsearch query"):
                 stdout, stderr = process.communicate()
                 retcode = process.wait()

             if retcode:
                 # Stderr is truncated to prevent proto size errors in Beam.
                 raise RuntimeError(
-                    'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n'
-                    % (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
+                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
+                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
+                )

             with open(hhr_path) as f:
                 hhr = f.read()
...
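And the corresponding sketch for HHsearch. Paths are hypothetical; the constructor keywords and the a3m-in, hhr-out contract of `query` are shown above:

from openfold.data.tools import hhsearch

searcher = hhsearch.HHSearch(
    binary_path="/usr/bin/hhsearch",  # hypothetical binary location
    databases=["/data/pdb70/pdb70"],  # hypothetical PDB70 prefix
)
with open("uniref90_hits.a3m") as f:  # hypothetical A3M file, e.g. from HHBlits
    hhr = searcher.query(f.read())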
openfold/data/tools/jackhmmer.py

...
@@ -29,7 +29,8 @@ from openfold.data.tools import utils
 class Jackhmmer:
     """Python wrapper of the Jackhmmer binary."""

-    def __init__(self,
+    def __init__(
+        self,
         *,
         binary_path: str,
         database_path: str,
...
@@ -44,7 +45,8 @@ class Jackhmmer:
         incdom_e: Optional[float] = None,
         dom_e: Optional[float] = None,
         num_streamed_chunks: Optional[int] = None,
-        streaming_callback: Optional[Callable[[int], None]] = None):
+        streaming_callback: Optional[Callable[[int], None]] = None,
+    ):
         """Initializes the Python Jackhmmer wrapper.

         Args:
...
@@ -69,9 +71,14 @@ class Jackhmmer:
         self.database_path = database_path
         self.num_streamed_chunks = num_streamed_chunks

-        if not os.path.exists(self.database_path) and num_streamed_chunks is None:
-            logging.error('Could not find Jackhmmer database %s', database_path)
-            raise ValueError(f'Could not find Jackhmmer database {database_path}')
+        if (
+            not os.path.exists(self.database_path)
+            and num_streamed_chunks is None
+        ):
+            logging.error("Could not find Jackhmmer database %s", database_path)
+            raise ValueError(
+                f"Could not find Jackhmmer database {database_path}"
+            )

         self.n_cpu = n_cpu
         self.n_iter = n_iter
...
@@ -85,11 +92,12 @@ class Jackhmmer:
         self.get_tblout = get_tblout
         self.streaming_callback = streaming_callback

-    def _query_chunk(self, input_fasta_path: str, database_path: str
+    def _query_chunk(
+        self, input_fasta_path: str, database_path: str
     ) -> Mapping[str, Any]:
         """Queries the database chunk using Jackhmmer."""
-        with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
-            sto_path = os.path.join(query_tmp_dir, 'output.sto')
+        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+            sto_path = os.path.join(query_tmp_dir, "output.sto")

             # The F1/F2/F3 are the expected proportion to pass each of the filtering
             # stages (which get progressively more expensive), reducing these
...
@@ -98,48 +106,63 @@ class Jackhmmer:
             # amount of time.
             cmd_flags = [
                 # Don't pollute stdout with Jackhmmer output.
-                '-o', '/dev/null',
-                '-A', sto_path,
-                '--noali',
-                '--F1', str(self.filter_f1),
-                '--F2', str(self.filter_f2),
-                '--F3', str(self.filter_f3),
-                '--incE', str(self.e_value),
+                "-o", "/dev/null",
+                "-A", sto_path,
+                "--noali",
+                "--F1", str(self.filter_f1),
+                "--F2", str(self.filter_f2),
+                "--F3", str(self.filter_f3),
+                "--incE", str(self.e_value),
                 # Report only sequences with E-values <= x in per-sequence output.
-                '-E', str(self.e_value),
-                '--cpu', str(self.n_cpu),
-                '-N', str(self.n_iter)
+                "-E", str(self.e_value),
+                "--cpu", str(self.n_cpu),
+                "-N", str(self.n_iter),
             ]
             if self.get_tblout:
-                tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
-                cmd_flags.extend(['--tblout', tblout_path])
+                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
+                cmd_flags.extend(["--tblout", tblout_path])
             if self.z_value:
-                cmd_flags.extend(['-Z', str(self.z_value)])
+                cmd_flags.extend(["-Z", str(self.z_value)])
             if self.dom_e is not None:
-                cmd_flags.extend(['--domE', str(self.dom_e)])
+                cmd_flags.extend(["--domE", str(self.dom_e)])
             if self.incdom_e is not None:
-                cmd_flags.extend(['--incdomE', str(self.incdom_e)])
+                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

-            cmd = [self.binary_path] + cmd_flags + [input_fasta_path, database_path]
+            cmd = (
+                [self.binary_path]
+                + cmd_flags
+                + [input_fasta_path, database_path]
+            )

-            logging.info('Launching subprocess "%s"', ' '.join(cmd))
-            process = subprocess.Popen(
-                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-            with utils.timing(
-                f'Jackhmmer ({os.path.basename(database_path)}) query'):
+            logging.info('Launching subprocess "%s"', " ".join(cmd))
+            process = subprocess.Popen(
+                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
+            with utils.timing(
+                f"Jackhmmer ({os.path.basename(database_path)}) query"
+            ):
                 _, stderr = process.communicate()
                 retcode = process.wait()

             if retcode:
-                raise RuntimeError(
-                    'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
+                raise RuntimeError(
+                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
+                )

             # Get e-values for each target name
-            tbl = ''
+            tbl = ""
             if self.get_tblout:
                 with open(tblout_path) as f:
                     tbl = f.read()
...
@@ -152,7 +175,8 @@ class Jackhmmer:
             tbl=tbl,
             stderr=stderr,
             n_iter=self.n_iter,
-            e_value=self.e_value)
+            e_value=self.e_value,
+        )
         return raw_output
...
@@ -162,15 +186,15 @@ class Jackhmmer:
             return [self._query_chunk(input_fasta_path, self.database_path)]

         db_basename = os.path.basename(self.database_path)
-        db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
-        db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
+        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
+        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

         # Remove existing files to prevent OOM
-        for f in glob.glob(db_local_chunk('[0-9]*')):
+        for f in glob.glob(db_local_chunk("[0-9]*")):
             try:
                 os.remove(f)
             except OSError:
-                print(f'OSError while deleting {f}')
+                print(f"OSError while deleting {f}")

         # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
         with futures.ThreadPoolExecutor(max_workers=2) as executor:
...
@@ -179,15 +203,22 @@ class Jackhmmer:
                 # Copy the chunk locally
                 if i == 1:
-                    future = executor.submit(
-                        request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
+                    future = executor.submit(
+                        request.urlretrieve,
+                        db_remote_chunk(i),
+                        db_local_chunk(i),
+                    )
                 if i < self.num_streamed_chunks:
-                    next_future = executor.submit(
-                        request.urlretrieve, db_remote_chunk(i + 1), db_local_chunk(i + 1))
+                    next_future = executor.submit(
+                        request.urlretrieve,
+                        db_remote_chunk(i + 1),
+                        db_local_chunk(i + 1),
+                    )

                 # Run Jackhmmer with the chunk
                 future.result()
-                chunked_output.append(
-                    self._query_chunk(input_fasta_path, db_local_chunk(i)))
+                chunked_output.append(
+                    self._query_chunk(input_fasta_path, db_local_chunk(i))
+                )

                 # Remove the local copy of the chunk
                 os.remove(db_local_chunk(i))
...
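The loop above overlaps network and compute: while Jackhmmer searches chunk i, a second worker thread is already downloading chunk i + 1. A stripped-down sketch of the same pattern, where fetch() and search() are hypothetical stand-ins for request.urlretrieve and self._query_chunk:

    # Generic prefetch pattern behind the streamed Jackhmmer search above.
    from concurrent.futures import ThreadPoolExecutor

    def run_streamed(fetch, search, num_chunks):
        outputs = []
        with ThreadPoolExecutor(max_workers=2) as executor:
            future = executor.submit(fetch, 1)            # prime the pipeline
            for i in range(1, num_chunks + 1):
                if i < num_chunks:
                    next_future = executor.submit(fetch, i + 1)  # prefetch i+1
                future.result()                           # wait for chunk i
                outputs.append(search(i))                 # search while i+1 downloads
                if i < num_chunks:
                    future = next_future
        return outputs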
openfold/data/tools/kalign.py
View file @
07e64267
...
@@ -25,12 +25,12 @@ from openfold.data.tools import utils
 def _to_a3m(sequences: Sequence[str]) -> str:
     """Converts sequences to an a3m file."""
-    names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
+    names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
     a3m = []
     for sequence, name in zip(sequences, names):
-        a3m.append(u'>' + name + u'\n')
-        a3m.append(sequence + u'\n')
-    return ''.join(a3m)
+        a3m.append(u">" + name + u"\n")
+        a3m.append(sequence + u"\n")
+    return "".join(a3m)


 class Kalign:
...
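For reference, the _to_a3m helper above simply emits FASTA-style records with generated names:

    # What _to_a3m produces for two toy sequences:
    print(_to_a3m(["MKVLAA", "MKILAG"]))
    # >sequence 1
    # MKVLAA
    # >sequence 2
    # MKILAG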
@@ -63,40 +63,51 @@ class Kalign:
             RuntimeError: If Kalign fails.
             ValueError: If any of the sequences is less than 6 residues long.
         """
-        logging.info('Aligning %d sequences', len(sequences))
+        logging.info("Aligning %d sequences", len(sequences))
         for s in sequences:
             if len(s) < 6:
-                raise ValueError('Kalign requires all sequences to be at least 6 '
-                                 'residues long. Got %s (%d residues).' % (s, len(s)))
+                raise ValueError(
+                    "Kalign requires all sequences to be at least 6 "
+                    "residues long. Got %s (%d residues)." % (s, len(s))
+                )

-        with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
-            input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
-            output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
+            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
+            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

-            with open(input_fasta_path, 'w') as f:
+            with open(input_fasta_path, "w") as f:
                 f.write(_to_a3m(sequences))

             cmd = [
                 self.binary_path,
-                '-i', input_fasta_path,
-                '-o', output_a3m_path,
-                '-format', 'fasta',
+                "-i", input_fasta_path,
+                "-o", output_a3m_path,
+                "-format", "fasta",
             ]

-            logging.info('Launching subprocess "%s"', ' '.join(cmd))
-            process = subprocess.Popen(
-                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            logging.info('Launching subprocess "%s"', " ".join(cmd))
+            process = subprocess.Popen(
+                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )

-            with utils.timing('Kalign query'):
+            with utils.timing("Kalign query"):
                 stdout, stderr = process.communicate()
                 retcode = process.wait()

-            logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
-                         stdout.decode('utf-8'), stderr.decode('utf-8'))
+            logging.info(
+                "Kalign stdout:\n%s\n\nstderr:\n%s\n",
+                stdout.decode("utf-8"),
+                stderr.decode("utf-8"),
+            )

             if retcode:
-                raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
-                                   % (stdout.decode('utf-8'), stderr.decode('utf-8')))
+                raise RuntimeError(
+                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
+                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
+                )

             with open(output_a3m_path) as f:
                 a3m = f.read()
...
openfold/data/tools/utils.py
View file @
07e64267
...
@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
 @contextlib.contextmanager
 def timing(msg: str):
-    logging.info('Started %s', msg)
+    logging.info("Started %s", msg)
     tic = time.time()
     yield
     toc = time.time()
-    logging.info('Finished %s in %.3f seconds', msg, toc - tic)
+    logging.info("Finished %s in %.3f seconds", msg, toc - tic)


 def to_date(s: str):
...
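timing is the small context manager used by every tool wrapper above ("HHsearch query", "Kalign query", ...): it logs a start line, yields, then logs elapsed wall time. For example:

    # Usage sketch of the timing context manager.
    from openfold.data.tools.utils import timing

    with timing("toy workload"):
        total = sum(i * i for i in range(100_000))
    # logs: "Started toy workload", then "Finished toy workload in 0.0xx seconds"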
openfold/model/__init__.py
View file @
07e64267
...
@@ -3,13 +3,14 @@ import glob
 import importlib as importlib

 _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
-__all__ = [os.path.basename(f)[:-3] for f in _files
-           if os.path.isfile(f) and not f.endswith("__init__.py")]
-_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
+__all__ = [
+    os.path.basename(f)[:-3]
+    for f in _files
+    if os.path.isfile(f) and not f.endswith("__init__.py")
+]
+_modules = [
+    (m, importlib.import_module("." + m, __name__)) for m in __all__
+]
 for _m in _modules:
     globals()[_m[0]] = _m[1]

 # Avoid needlessly cluttering the global namespace
 del _files, _m, _modules
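The loop above is a dynamic re-export: it globs every sibling .py file, imports each one, and binds it as an attribute of openfold.model. Written out by hand for two of the submodules, the effect is roughly:

    # Hand-written equivalent of the dynamic-import loop, for illustration:
    from openfold.model import dropout, evoformer  # etc., one per .py file
    # Each submodule name also appears in openfold.model.__all__.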
openfold/model/dropout.py
View file @
07e64267
...
@@ -26,6 +26,7 @@ class Dropout(nn.Module):
     If not in training mode, this module computes the identity function.
     """
+
     def __init__(self, r: float, batch_dim: Union[int, List[int]]):
         """
         Args:
...
@@ -37,7 +38,7 @@ class Dropout(nn.Module):
         super(Dropout, self).__init__()

         self.r = r
-        if(type(batch_dim) == int):
+        if type(batch_dim) == int:
             batch_dim = [batch_dim]
         self.batch_dim = batch_dim
         self.dropout = nn.Dropout(self.r)
...
@@ -50,7 +51,7 @@ class Dropout(nn.Module):
                 compatible with self.batch_dim
         """
         shape = list(x.shape)
-        if(self.batch_dim is not None):
+        if self.batch_dim is not None:
             for bd in self.batch_dim:
                 shape[bd] = 1
         mask = x.new_ones(shape)
...
@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
     Convenience class for rowwise dropout as described in subsection
     1.11.6.
     """
+
     __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
...
@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
     Convenience class for columnwise dropout as described in subsection
     1.11.6.
     """
+
     __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
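The changes here are cosmetic, but the class is worth a gloss: the dropout mask is sampled with size 1 along batch_dim and then broadcast, so a dropped entry is zeroed at every index along that dimension at once, per subsection 1.11.6 of the AlphaFold 2 supplement. A small standalone sketch of the mechanism:

    # Sketch of shared-mask dropout: with the masked dimension collapsed to 1,
    # the same Bernoulli draw is broadcast along it.
    import torch

    x = torch.randn(8, 8, 4)        # toy [N_res, N_res, C] activations
    shape = list(x.shape)
    shape[-3] = 1                   # what DropoutRowwise does (batch_dim=-3)
    mask = torch.nn.functional.dropout(x.new_ones(shape), p=0.25, training=True)
    y = x * mask                    # zeroed entries are shared across dim -3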
openfold/model/embedders.py
View file @
07e64267
...
@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
     Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
     """
+
     def __init__(
         self,
         tf_dim: int,
...
@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
         self.no_bins = 2 * relpos_k + 1
         self.linear_relpos = Linear(self.no_bins, c_z)

-    def relpos(self,
-        ri: torch.Tensor
-    ):
+    def relpos(self, ri: torch.Tensor):
         """
         Computes relative positional encodings
...
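relpos (Algorithm 4) one-hot encodes clipped pairwise residue-index offsets into 2 * relpos_k + 1 bins before the linear projection. A toy equivalent, where relpos_k = 32 is an assumed default and torch.nn.functional.one_hot stands in for OpenFold's boundary-based one_hot helper:

    # Toy re-derivation of relpos for integer residue indices.
    import torch

    relpos_k = 32                                     # assumed value
    ri = torch.arange(6)                              # residue_index
    d = ri[..., :, None] - ri[..., None, :]           # pairwise offsets
    d = d.clamp(-relpos_k, relpos_k) + relpos_k       # shift into [0, 2k]
    oh = torch.nn.functional.one_hot(d, 2 * relpos_k + 1).float()
    # oh: [N_res, N_res, 65]; linear_relpos then maps the last dim to c_z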
@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
         oh = one_hot(d, boundaries).type(ri.dtype)
         return self.linear_relpos(oh)

-    def forward(self,
+    def forward(
+        self,
         tf: torch.Tensor,
         ri: torch.Tensor,
         msa: torch.Tensor,
...
@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
     Implements Algorithm 32.
     """

-    def __init__(self,
+    def __init__(
+        self,
         c_m: int,
         c_z: int,
         min_bin: float,
         max_bin: float,
         no_bins: int,
         inf: float = 1e8,
-        **kwargs
+        **kwargs,
     ):
         """
         Args:
...
@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
         self.layer_norm_m = nn.LayerNorm(self.c_m)
         self.layer_norm_z = nn.LayerNorm(self.c_z)

-    def forward(self,
+    def forward(
+        self,
         m: torch.Tensor,
         z: torch.Tensor,
         x: torch.Tensor,
...
@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if(self.bins is None):
+        if self.bins is None:
             self.bins = torch.linspace(
                 self.min_bin,
                 self.max_bin,
                 self.no_bins,
                 dtype=x.dtype,
-                device=x.device
+                device=x.device,
             )

         # [*, N, C_m]
...
@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
         # couldn't find in time.
         squared_bins = self.bins ** 2
         upper = torch.cat(
-            [squared_bins[1:], squared_bins.new_tensor([self.inf])
-            ], dim=-1
+            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
         d = torch.sum(
-            (x[..., None, :] - x[..., None, :, :]) ** 2,
-            dim=-1,
-            keepdims=True
+            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
         )

         # [*, N, N, no_bins]
...
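The reformatted lines compute a squared-distance binning over the recycled coordinates: each pairwise d² is compared against consecutive squared bin edges, with the last bin left open-ended via self.inf. A toy sketch with assumed bin settings:

    # Sketch of the squared-distance binning above; the min_bin/max_bin/no_bins
    # values are illustrative assumptions.
    import torch

    bins = torch.linspace(3.25, 20.75, 15)
    squared_bins = bins ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([1e8])], dim=-1)

    x = torch.randn(10, 3)  # toy [N_res, 3] coordinates
    d = torch.sum(
        (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # 1 exactly where d lies between consecutive squared bin edges
    d_onehot = ((d > squared_bins) * (d < upper)).type(x.dtype)  # [10, 10, 15]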
@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
     Implements Algorithm 2, line 7.
     """

-    def __init__(self,
+    def __init__(
+        self,
         c_in: int,
         c_out: int,
         **kwargs,
...
@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
         self.relu = nn.ReLU()
         self.linear_2 = Linear(self.c_out, self.c_out, init="relu")

-    def forward(self,
-        x: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x: [*, N_templ, N_res, c_in] "template_angle_feat" features
...
@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
     Implements Algorithm 2, line 9.
     """

-    def __init__(self,
+    def __init__(
+        self,
         c_in: int,
         c_out: int,
         **kwargs,
...
@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
         # Despite there being no relu nearby, the source uses that initializer
         self.linear = Linear(self.c_in, self.c_out, init="relu")

-    def forward(self,
+    def forward(
+        self,
         x: torch.Tensor,
     ) -> torch.Tensor:
         """
...
@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
     Implements Algorithm 2, line 15
     """

-    def __init__(self,
+    def __init__(
+        self,
         c_in: int,
         c_out: int,
         **kwargs,
...
@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
         self.linear = Linear(self.c_in, self.c_out)

-    def forward(self,
-        x: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x:
...
openfold/model/evoformer.py
View file @
07e64267
...
@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
     Implements Algorithm 9
     """
+
     def __init__(self, c_m, n, chunk_size):
         """
         Args:
...
@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
         m = self.linear_2(m) * mask
         return m

-    def forward(self,
+    def forward(
+        self,
         m: torch.Tensor,
         mask: torch.Tensor = None,
     ) -> torch.Tensor:
...
@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
             [*, N_seq, N_res, C_m] MSA activation update
         """
         # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
-        if(mask is None):
+        if mask is None:
             mask = m.new_ones(m.shape[:-1])

         mask = mask.unsqueeze(-1)
...
@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)

         inp = {"m": m, "mask": mask}
-        if(self.chunk_size is not None):
+        if self.chunk_size is not None:
             m = chunk_layer(
                 self._transition,
                 inp,
...
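chunk_layer, used here and elsewhere in the model, bounds peak memory by slicing the batch-like dimensions into chunk_size pieces and running the wrapped function piecewise. A simplified single-tensor sketch of the idea (OpenFold's actual helper also handles dicts of inputs, as in the call above):

    # Simplified layer chunking: apply fn to slices and re-concatenate,
    # trading a longer loop for a smaller activation footprint.
    import torch

    def chunked_apply(fn, x, chunk_size, dim=-2):
        if chunk_size is None:
            return fn(x)
        pieces = [fn(p) for p in torch.split(x, chunk_size, dim=dim)]
        return torch.cat(pieces, dim=dim)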
@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
 class EvoformerBlock(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         c_m: int,
         c_z: int,
         c_hidden_msa_att: int,
...
@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )

-        if(_is_extra_msa_stack):
+        if _is_extra_msa_stack:
             self.msa_att_col = MSAColumnGlobalAttention(
                 c_in=c_m,
                 c_hidden=c_hidden_msa_att,
...
@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
         self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

-    def forward(self,
+    def forward(
+        self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
...
@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
     Implements Algorithm 6.
     """

-    def __init__(self,
+    def __init__(
+        self,
         c_m: int,
         c_z: int,
         c_hidden_msa_att: int,
...
@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
             )
             self.blocks.append(block)

-        if(not self._is_extra_msa_stack):
+        if not self._is_extra_msa_stack:
             self.linear = Linear(c_m, c_s)

-    def forward(self,
+    def forward(
+        self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
...
@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
                     msa_mask=msa_mask,
                     pair_mask=pair_mask,
                     _mask_trans=_mask_trans,
-                ) for b in self.blocks
+                )
+                for b in self.blocks
             ],
             args=(m, z),
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )

         s = None
-        if(not self._is_extra_msa_stack):
+        if not self._is_extra_msa_stack:
             seq_dim = -3
             index = torch.tensor([0], device=m.device)
             s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def
__init__
(
self
,
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
c_hidden_msa_att
:
int
,
...
...
@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
_is_extra_msa_stack
=
True
,
)
def
forward
(
self
,
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
z
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
_mask_trans
_mask_trans
=
_mask_trans
,
)
return
z
openfold/model/heads.py
View file @
07e64267
...
@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
             **config["experimentally_resolved"],
         )

-        if(config.tm.enabled):
+        if config.tm.enabled:
             self.tm = TMScoreHead(
                 **config.tm,
             )
...
@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
         experimentally_resolved_logits = self.experimentally_resolved(
             outputs["single"]
         )
-        aux_out["experimentally_resolved_logits"] = (
-            experimentally_resolved_logits
-        )
+        aux_out[
+            "experimentally_resolved_logits"
+        ] = experimentally_resolved_logits

-        if(self.config.tm.enabled):
+        if self.config.tm.enabled:
             tm_logits = self.tm(outputs["pair"])
             aux_out["tm_logits"] = tm_logits
             aux_out["predicted_tm_score"] = compute_tm(tm_logits, **self.config.tm)
-            aux_out.update(compute_predicted_aligned_error(
-                tm_logits,
-                **self.config.tm,
-            ))
+            aux_out.update(
+                compute_predicted_aligned_error(
+                    tm_logits,
+                    **self.config.tm,
+                )
+            )

         return aux_out
...
@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
     For use in computation of distogram loss, subsection 1.9.8
     """
+
     def __init__(self, c_z, no_bins, **kwargs):
         """
         Args:
...
@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
         self.linear = Linear(self.c_z, self.no_bins, init="final")

-    def forward(self,
-        z  # [*, N, N, C_z]
-    ):
+    def forward(self, z):  # [*, N, N, C_z]
         """
         Args:
             z:
...
@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
     """
     For use in computation of TM-score, subsection 1.9.7
     """
+
     def __init__(self, c_z, no_bins, **kwargs):
         """
         Args:
...
@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
     """
     For use in computation of masked MSA loss, subsection 1.9.9
     """
+
     def __init__(self, c_m, c_out, **kwargs):
         """
         Args:
...
@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
     For use in computation of "experimentally resolved" loss, subsection
     1.9.10
     """
+
     def __init__(self, c_s, c_out, **kwargs):
         """
         Args:
...
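These heads are thin projections over the single or pair representation. A shape-check sketch for DistogramHead, where c_z = 128 and no_bins = 64 are illustrative values, not ones fixed by this commit:

    # Hypothetical shape check for DistogramHead; the sizes are assumptions.
    import torch
    from openfold.model.heads import DistogramHead

    head = DistogramHead(c_z=128, no_bins=64)
    z = torch.randn(1, 50, 50, 128)   # [*, N_res, N_res, C_z] pair activations
    logits = head(z)                  # [1, 50, 50, 64] distogram logits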
openfold/model/model.py
View file @
07e64267
This diff is collapsed.
openfold/model/msa.py
View file @
07e64267
This diff is collapsed.