OpenDAS / OpenFold · Commits

Commit d48c052c, authored Oct 15, 2021 by Gustaf Ahdritz

    Add training parsers

Parent: eeda001c
Changes: 23 · Showing 20 changed files with 1676 additions and 919 deletions (+1676, −919)
Files changed:
  openfold/config.py (+281, −233)
  openfold/features/data_pipeline.py (+336, −0)
  openfold/features/data_transforms.py (+590, −47)
  openfold/features/feature_pipeline.py (+48, −17)
  openfold/features/input_pipeline.py (+53, −24)
  openfold/features/mmcif_parsing.py (+77, −2)
  openfold/features/np/hhsearch.py (+5, −1)
  openfold/features/np/jackhmmer.py (+1, −3)
  openfold/features/np/utils.py (+7, −0)
  openfold/features/templates.py (+43, −53)
  openfold/model/model.py (+28, −29)
  openfold/utils/exponential_moving_average.py (+9, −8)
  openfold/utils/feats.py (+6, −386)
  openfold/utils/loss.py (+8, −35)
  run_pretrained_openfold.py (+23, −80)
  scripts/build_deepspeed_config.py (+1, −0)
  scripts/precompute_alignments.py (+107, −0)
  scripts/utils.py (+48, −0)
  tests/compare_utils.py (+3, −0)
  tests/test_feats.py (+2, −1)
openfold/config.py
  (diff collapsed; not shown)
openfold/features/np/data_pipeline.py → openfold/features/data_pipeline.py (renamed)
import os
import datetime
import numpy as np
from typing import Mapping, Optional, Sequence, Any

from openfold.features import templates, parsers, mmcif_parsing
from openfold.features.np import jackhmmer, hhblits, hhsearch
from openfold.features.np.utils import to_date
from openfold.np import residue_constants

FeatureDict = Mapping[str, np.ndarray]


def make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    features = {}
    features['aatype'] = residue_constants.sequence_to_onehot(
...
@@ -19,13 +25,50 @@ def make_sequence_features(sequence: str, description: str, num_res: int) -> Fea
        map_unknown_to_x=True
    )
    features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
    features['domain_name'] = np.array(
        [description.encode('utf-8')], dtype=np.object_
    )
    features['residue_index'] = np.array(range(num_res), dtype=np.int32)
    features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
    features['sequence'] = np.array(
        [sequence.encode('utf-8')], dtype=np.object_
    )
    return features


def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    input_sequence = mmcif_object.chain_to_seqres[chain_id]
    description = '_'.join([mmcif_object.file_id, chain_id])
    num_res = len(input_sequence)

    mmcif_feats = {}

    mmcif_feats.update(make_sequence_features(
        sequence=input_sequence,
        description=description,
        num_res=num_res,
    ))

    all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    mmcif_feats["all_atom_positions"] = all_atom_positions
    mmcif_feats["all_atom_mask"] = all_atom_mask

    mmcif_feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    mmcif_feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_
    )

    return mmcif_feats


def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
...
@@ -58,9 +101,9 @@ def make_msa_features(
    )
    return features


class AlignmentRunner:
    """ Runs alignment tools and saves the results """
    def __init__(self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
...
@@ -71,106 +114,158 @@ class DataPipeline:
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        pdb70_database_path: str,
        use_small_bfd: bool,
        no_cpus: int,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        self._use_small_bfd = use_small_bfd

        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path,
            n_cpu=no_cpus,
        )

        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path,
                n_cpu=no_cpus,
            )
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
                n_cpu=no_cpus,
            )

        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path,
            n_cpu=no_cpus,
        )

        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path,
            databases=[pdb70_database_path]
        )

        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

    def run(self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence"""
        jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
        )
        uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
        with open(uniref90_out_path, 'w') as f:
            f.write(uniref90_msa_as_a3m)

        jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
        mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
        )
        mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
        with open(mgnify_out_path, 'w') as f:
            f.write(mgnify_msa_as_a3m)

        hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
        pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
        with open(pdb70_out_path, 'w') as f:
            f.write(hhsearch_result)

        if self._use_small_bfd:
            jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
            bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
            with open(bfd_out_path, 'w') as f:
                f.write(jackhmmer_small_bfd_result['sto'])
        else:
            hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
            if (output_dir is not None):
                bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
                with open(bfd_out_path, 'w') as f:
                    f.write(hhblits_bfd_uniclust_result['a3m'])


class DataPipeline:
    """Assembles input features."""
    def __init__(self,
        template_featurizer: templates.TemplateHitFeaturizer,
        use_small_bfd: bool,
    ):
        self.template_featurizer = template_featurizer
        self.use_small_bfd = use_small_bfd

    def _parse_alignment_output(self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
        with open(uniref90_out_path, 'r') as f:
            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())

        mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
        with open(mgnify_out_path, 'r') as f:
            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())

        pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
        with open(pdb70_out_path, 'r') as f:
            hhsearch_hits = parsers.parse_hhr(f.read())

        if (self.use_small_bfd):
            bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
            with open(bfd_out_path, 'r') as f:
                bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(f.read())
        else:
            bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
            with open(bfd_out_path, 'r') as f:
                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())

        return {
            'uniref90_msa': uniref90_msa,
            'uniref90_deletion_matrix': uniref90_deletion_matrix,
            'mgnify_msa': mgnify_msa,
            'mgnify_deletion_matrix': mgnify_deletion_matrix,
            'hhsearch_hits': hhsearch_hits,
            'bfd_msa': bfd_msa,
            'bfd_deletion_matrix': bfd_deletion_matrix,
        }

    def process_fasta(self,
        fasta_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file"""
        with open(fasta_path) as f:
            fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f'More than one input sequence found in {fasta_path}.'
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        alignments = self._parse_alignment_output(alignment_dir)

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments['hhsearch_hits']
        )

        sequence_features = make_sequence_features(
...
@@ -180,9 +275,62 @@ class DataPipeline:
        )

        msa_features = make_msa_features(
            msas=(
                alignments['uniref90_msa'],
                alignments['bfd_msa'],
                alignments['mgnify_msa']
            ),
            deletion_matrices=(
                alignments['uniref90_deletion_matrix'],
                alignments['bfd_deletion_matrix'],
                alignments['mgnify_deletion_matrix']
            )
        )

        return {**sequence_features, **msa_features, **templates_result.features}

    def process_mmcif(self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
    ) -> FeatureDict:
        """
            Assembles features for a specific chain in an mmCIF object.

            If chain_id is None, it is assumed that there is only one chain
            in the object. Otherwise, a ValueError is thrown.
        """
        if (chain_id is None):
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if (chain is None):
                raise ValueError('No chains in mmCIF file')
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        alignments = self._parse_alignment_output(alignment_dir)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=to_date(mmcif.header["release_date"]),
            hits=alignments['hhsearch_hits']
        )

        msa_features = make_msa_features(
            msas=(
                alignments['uniref90_msa'],
                alignments['bfd_msa'],
                alignments['mgnify_msa']
            ),
            deletion_matrices=(
                alignments['uniref90_deletion_matrix'],
                alignments['bfd_deletion_matrix'],
                alignments['mgnify_deletion_matrix']
            )
        )

        return {**mmcif_feats, **templates_result.features, **msa_features}
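The split above separates alignment generation (AlignmentRunner, which writes files to an output directory) from feature assembly (DataPipeline, which reads them back from an alignment directory). Below is a minimal usage sketch, not part of the commit: the binary and database paths are placeholders, several constructor parameters (hhsearch_binary_path, uniref90_database_path, mgnify_database_path, bfd_database_path) sit in the collapsed part of the hunk so their exact names are assumptions, and template_featurizer is assumed to be a templates.TemplateHitFeaturizer built as elsewhere in the repository.

    # Sketch only; paths and some keyword names are assumptions, not values
    # taken from this commit.
    from openfold.features.data_pipeline import AlignmentRunner, DataPipeline

    runner = AlignmentRunner(
        jackhmmer_binary_path="/usr/bin/jackhmmer",
        hhblits_binary_path="/usr/bin/hhblits",
        hhsearch_binary_path="/usr/bin/hhsearch",
        uniref90_database_path="/data/uniref90/uniref90.fasta",
        mgnify_database_path="/data/mgnify/mgy_clusters.fa",
        bfd_database_path=None,
        uniclust30_database_path=None,
        small_bfd_database_path="/data/small_bfd/bfd-first_non_consensus_sequences.fasta",
        pdb70_database_path="/data/pdb70/pdb70",
        use_small_bfd=True,
        no_cpus=8,
    )
    runner.run("query.fasta", "alignments/")  # writes uniref90/mgnify/pdb70/bfd hit files

    pipeline = DataPipeline(
        template_featurizer=template_featurizer,  # assumed to be constructed elsewhere
        use_small_bfd=True,
    )
    feature_dict = pipeline.process_fasta(
        fasta_path="query.fasta",
        alignment_dir="alignments/",
    )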
openfold/features/data_transforms.py
  (diff collapsed; not shown)
openfold/features/feature_pipeline.py
...
@@ -25,39 +25,67 @@ def np_to_tensor_dict(
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
  """
  tensor_dict = {
      k: torch.tensor(v) for k, v in np_example.items() if k in features
  }

  return tensor_dict


def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if (mode_cfg.crop_size is None):
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    if (cfg[mode].supervised):
        feature_names += cfg.common.supervised_features

    return cfg, feature_names


def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
    batch_mode: str,
):
    np_example = dict(np_example)
    num_res = int(np_example['seq_length'][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if 'deletion_matrix_int' in np_example:
        np_example['deletion_matrix'] = (
            np_example.pop('deletion_matrix_int').astype(np.float32)
        )

    if batch_mode == 'clamped':
        np_example['use_clamped_fape'] = (
            np.array(1.).astype(np.float32)
        )
    elif batch_mode == 'unclamped':
        np_example['use_clamped_fape'] = (
            np.array(0.).astype(np.float32)
        )

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
            batch_mode=batch_mode,
        )

    return {k: v for k, v in features.items()}
...
@@ -70,10 +98,13 @@ class FeaturePipeline:
        self.params = params

    def process_features(self,
        raw_features: FeatureDict,
        mode: str = 'train',
        batch_mode: str = 'clamped',
    ) -> FeatureDict:
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            mode=mode,
            batch_mode=batch_mode,
        )
\ No newline at end of file
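With the mode and batch_mode plumbing above, turning a raw feature dict into model-ready tensors looks roughly like the sketch below. It calls the np_example_to_features function shown in the hunk; data_config is assumed to be the data section of the project config (with common, train and eval sub-configs), and raw_features is whatever DataPipeline.process_fasta or process_mmcif returned.

    # Sketch: feature processing for a clamped-FAPE training example.
    from openfold.features.feature_pipeline import np_example_to_features

    processed = np_example_to_features(
        np_example=raw_features,  # FeatureDict from DataPipeline
        config=data_config,       # assumed: the config's data section
        mode="train",             # selects cfg['train'] inside make_data_config
        batch_mode="clamped",     # sets use_clamped_fape = 1.0 for this example
    )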
openfold/features/input_pipeline.py
from functools import partial

import torch

from openfold.features import data_transforms


def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
...
@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
        data_transforms.make_template_mask,
        data_transforms.make_pseudo_beta('template_')
    ])
    if (common_cfg.use_template_torsion_angles):
        transforms.extend([
            data_transforms.atom37_to_torsion_angles('template_'),
        ])

    transforms.extend([
        data_transforms.make_atom14_masks,
    ])

    if (mode_cfg.supervised):
        transforms.extend([
            data_transforms.make_atom14_positions,
            data_transforms.atom37_to_frames,
            data_transforms.atom37_to_torsion_angles(''),
            data_transforms.make_pseudo_beta(''),
            data_transforms.get_backbone_frames,
            data_transforms.get_chi_angles,
        ])

    return transforms


def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa
...
@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    transforms.append(
        data_transforms.make_masked_msa(
            common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
        )
    )

    if common_cfg.msa_cluster_features:
...
@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
    transforms.append(data_transforms.make_msa_feat())

    crop_feats = dict(common_cfg.feat)

    if mode_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(data_transforms.random_crop_to_size(
            mode_cfg.crop_size,
            mode_cfg.max_templates,
            crop_feats,
            mode_cfg.subsample_templates,
            batch_mode=batch_mode,
            seed=torch.Generator().seed()
        ))
        transforms.append(data_transforms.make_fixed_size(
            crop_feats,
            pad_msa_clusters,
            common_cfg.max_extra_msa,
            mode_cfg.crop_size,
            mode_cfg.max_templates
        ))
    else:
        transforms.append(
            data_transforms.crop_templates(mode_cfg.max_templates)
        )

    return transforms


def process_tensors_from_config(tensors, common_cfg, mode_cfg, batch_mode='clamped'):
    """Based on the config, apply filters and transformations to the data."""

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
        fn = compose(fns)
        d['ensemble_index'] = i
        return fn(d)

    tensors = compose(
        nonensembled_transform_fns(common_cfg, mode_cfg)
    )(tensors)

    tensors_0 = wrap_ensemble_fn(tensors, 0)
    num_ensemble = mode_cfg.num_ensemble
    if common_cfg.resample_msa_in_recycling:
        # Separate batch per ensembling & recycling step.
        num_ensemble *= common_cfg.num_recycle + 1

    if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
        tensors = map_fn(
            lambda x: wrap_ensemble_fn(tensors, x),
...
@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
    return tensors


@data_transforms.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x


def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
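One consequence of stacking on dim=-1 in map_fn is that the ensemble/recycling axis ends up last, which is what the per-cycle slicing in model.py below (fetch_cur_batch = lambda t: t[..., cycle_no]) expects. A small shape-only illustration, with placeholder feature names and sizes:

    import torch

    # One dict per ensemble/recycling step (hypothetical feature and shape).
    ensembles = [
        {"msa_feat": torch.zeros(128, 256, 49)},
        {"msa_feat": torch.ones(128, 256, 49)},
    ]
    stacked = {
        k: torch.stack([d[k] for d in ensembles], dim=-1)
        for k in ensembles[0]
    }
    print(stacked["msa_feat"].shape)         # torch.Size([128, 256, 49, 2])
    print(stacked["msa_feat"][..., 0].shape) # per-step slice: torch.Size([128, 256, 49])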
openfold/features/mmcif_parsing.py
"""Parses the mmCIF file format."""
"""Parses the mmCIF file format."""
import
collections
import
collections
import
dataclasses
import
dataclasses
import
io
import
io
import
json
import
logging
import
os
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
typing
import
Any
,
Mapping
,
Optional
,
Sequence
,
Tuple
from
absl
import
logging
from
Bio
import
PDB
from
Bio
import
PDB
from
Bio.Data
import
SCOPData
from
Bio.Data
import
SCOPData
import
numpy
as
np
import
openfold.np.residue_constants
as
residue_constants
# Type aliases:
# Type aliases:
ChainId
=
str
ChainId
=
str
...
@@ -369,3 +374,73 @@ def _get_protein_chains(
...
@@ -369,3 +374,73 @@ def _get_protein_chains(
def
_is_set
(
data
:
str
)
->
bool
:
def
_is_set
(
data
:
str
)
->
bool
:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return
data
not
in
(
'.'
,
'?'
)
return
data
not
in
(
'.'
,
'?'
)
def
get_atom_coords
(
mmcif_object
:
MmcifObject
,
chain_id
:
str
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
# Locate the right chain
chains
=
list
(
mmcif_object
.
structure
.
get_chains
())
relevant_chains
=
[
c
for
c
in
chains
if
c
.
id
==
chain_id
]
if
len
(
relevant_chains
)
!=
1
:
raise
MultipleChainsError
(
f
'Expected exactly one chain in structure with id
{
chain_id
}
.'
)
chain
=
relevant_chains
[
0
]
# Extract the coordinates
num_res
=
len
(
mmcif_object
.
chain_to_seqres
[
chain_id
])
all_atom_positions
=
np
.
zeros
(
[
num_res
,
residue_constants
.
atom_type_num
,
3
],
dtype
=
np
.
float32
)
all_atom_mask
=
np
.
zeros
(
[
num_res
,
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
for
res_index
in
range
(
num_res
):
pos
=
np
.
zeros
([
residue_constants
.
atom_type_num
,
3
],
dtype
=
np
.
float32
)
mask
=
np
.
zeros
([
residue_constants
.
atom_type_num
],
dtype
=
np
.
float32
)
res_at_position
=
mmcif_object
.
seqres_to_structure
[
chain_id
][
res_index
]
if
not
res_at_position
.
is_missing
:
res
=
chain
[(
res_at_position
.
hetflag
,
res_at_position
.
position
.
residue_number
,
res_at_position
.
position
.
insertion_code
)]
for
atom
in
res
.
get_atoms
():
atom_name
=
atom
.
get_name
()
x
,
y
,
z
=
atom
.
get_coord
()
if
atom_name
in
residue_constants
.
atom_order
.
keys
():
pos
[
residue_constants
.
atom_order
[
atom_name
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
atom_name
]]
=
1.0
elif
atom_name
.
upper
()
==
'SE'
and
res
.
get_resname
()
==
'MSE'
:
# Put the coords of the selenium atom in the sulphur column
pos
[
residue_constants
.
atom_order
[
'SD'
]]
=
[
x
,
y
,
z
]
mask
[
residue_constants
.
atom_order
[
'SD'
]]
=
1.0
all_atom_positions
[
res_index
]
=
pos
all_atom_mask
[
res_index
]
=
mask
return
all_atom_positions
,
all_atom_mask
def
generate_mmcif_cache
(
mmcif_dir
:
str
,
out_path
:
str
):
data
=
{}
for
f
in
os
.
listdir
(
mmcif_dir
):
if
(
f
.
endswith
(
'.cif'
)):
with
open
(
os
.
path
.
join
(
mmcif_dir
,
f
),
'r'
)
as
fp
:
mmcif_string
=
fp
.
read
()
file_id
=
os
.
path
.
splitext
(
f
)[
0
]
mmcif
=
parse
(
file_id
=
file_id
,
mmcif_string
=
mmcif_string
)
if
(
mmcif
.
mmcif_object
is
None
):
logging
.
warning
(
f
'Could not parse
{
f
}
. Skipping...'
)
continue
else
:
mmcif
=
mmcif
.
mmcif_object
local_data
=
{}
local_data
[
'release_date'
]
=
mmcif
.
header
[
"release_date"
]
local_data
[
'no_chains'
]
=
len
(
list
(
mmcif
.
structure
.
get_chains
()))
data
[
file_id
]
=
local_data
with
open
(
out_path
,
'w'
)
as
fp
:
fp
.
write
(
json
.
dumps
(
data
))
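generate_mmcif_cache walks a directory of .cif files once and records the two header fields the training parsers need. A minimal, hypothetical invocation (the paths are placeholders, not values from this commit):

    # Sketch: build the mmCIF header cache for a directory of structure files.
    from openfold.features import mmcif_parsing

    mmcif_parsing.generate_mmcif_cache(
        mmcif_dir="/data/pdb_mmcif/mmcif_files",  # placeholder path
        out_path="mmcif_cache.json",              # placeholder path
    )
    # The resulting JSON maps each file ID to {'release_date': ..., 'no_chains': ...}.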
openfold/features/np/hhsearch.py
...
@@ -18,6 +18,7 @@ class HHSearch:
               *,
               binary_path: str,
               databases: Sequence[str],
               n_cpu: int = 2,
               maxseq: int = 1_000_000):
    """Initializes the Python HHsearch wrapper.
...
@@ -26,6 +27,7 @@ class HHSearch:
      databases: A sequence of HHsearch database paths. This should be the
        common prefix for the database files (i.e. up to but not including
        _hhm.ffindex etc.)
      n_cpu: The number of CPUs to use
      maxseq: The maximum number of rows in an input alignment. Note that this
        parameter is only supported in HHBlits version 3.1 and higher.
...
@@ -34,6 +36,7 @@ class HHSearch:
    """
    self.binary_path = binary_path
    self.databases = databases
    self.n_cpu = n_cpu
    self.maxseq = maxseq

    for database_path in self.databases:
...
@@ -56,7 +59,8 @@ class HHSearch:
      cmd = [self.binary_path,
             '-i', input_path,
             '-o', hhr_path,
             '-maxseq', str(self.maxseq),
             '-cpu', str(self.n_cpu),
             ] + db_cmd

      logging.info('Launching subprocess "%s"', ' '.join(cmd))
...
openfold/features/np/jackhmmer.py
...
@@ -3,14 +3,12 @@
from concurrent import futures
import glob
import logging
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from openfold.features.np import utils
...
openfold/features/np/utils.py
"""Common utilities for data pipeline tools."""
"""Common utilities for data pipeline tools."""
import
contextlib
import
contextlib
import
datetime
import
shutil
import
shutil
import
tempfile
import
tempfile
import
time
import
time
...
@@ -25,3 +26,9 @@ def timing(msg: str):
...
@@ -25,3 +26,9 @@ def timing(msg: str):
yield
yield
toc
=
time
.
time
()
toc
=
time
.
time
()
logging
.
info
(
'Finished %s in %.3f seconds'
,
msg
,
toc
-
tic
)
logging
.
info
(
'Finished %s in %.3f seconds'
,
msg
,
toc
-
tic
)
def
to_date
(
s
:
str
):
return
datetime
.
datetime
(
year
=
int
(
s
[:
4
]),
month
=
int
(
s
[
5
:
7
]),
day
=
int
(
s
[
8
:
10
])
)
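to_date simply slices an ISO-style YYYY-MM-DD string, such as the mmCIF release_date header field, into a datetime. A quick check (sketch, not part of the commit):

    from openfold.features.np.utils import to_date

    print(to_date("2021-10-15"))  # 2021-10-15 00:00:00 (a datetime.datetime)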
openfold/features/templates.py
...
@@ -2,16 +2,17 @@
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

import numpy as np

from openfold.features import parsers, mmcif_parsing
from openfold.features.np import kalign
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
...
@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
    'template_aatype': np.int64,
    'template_all_atom_mask': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': np.object,
    'template_sequence': np.object,
...
@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
  return result


def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    dates = {}
    for f in os.listdir(mmcif_dir):
        if (f.endswith('.cif')):
            path = os.path.join(mmcif_dir, f)
            with open(path, 'r') as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(file_id=file_id, mmcif_string=mmcif_string)
            if (mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}. Skipping...')
                continue
            mmcif = mmcif.mmcif_object
            release_date = mmcif.header['release_date']
            dates[file_id] = release_date

    with open(out_path, 'r') as fp:
        fp.write(json.dumps(dates))


def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
  """Parses release dates file, returns a mapping from PDBs to release dates."""
  if path.endswith('txt'):
      with open(path, 'r') as fp:
          data = json.load(fp)
      return {
          pdb: to_date(v) for pdb, d in data.items()
          for k, v in d.items() if k == "release_date"
      }
  else:
      raise ValueError('Invalid format of the release date file %s.' % path)


def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
...
@@ -419,42 +437,14 @@ def _get_atom_positions(
    auth_chain_id: str,
    max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
  """Gets atom positions and mask from a list of Biopython Residues."""
  coords_with_mask = mmcif_parsing.get_atom_coords(
      mmcif_object=mmcif_object, chain_id=auth_chain_id
  )
  all_atom_positions, all_atom_mask = coords_with_mask
  _check_residue_distances(
      all_atom_positions, all_atom_mask, max_ca_ca_distance
  )
  return all_atom_positions, all_atom_mask


def _extract_template_features(
...
@@ -579,7 +569,7 @@ def _extract_template_features(
  return (
      {
          'template_all_atom_positions': np.array(templates_all_atom_positions),
          'template_all_atom_mask': np.array(templates_all_atom_masks),
          'template_sequence': output_templates_sequence.encode(),
          'template_aatype': np.array(templates_aatype),
          'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
...
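generate_release_dates_cache and the rewritten _parse_release_dates are meant to be used as a pair: the first writes a JSON mapping of file IDs to release dates, the second reads it back through to_date. Two hedged observations on the hunk as shown: the writer opens out_path in mode 'r' before calling write, which presumably needs to be 'w' for the cache to actually be produced, and _parse_release_dates still keys off a 'txt' suffix even though the payload is JSON. A minimal, hypothetical invocation (placeholder paths):

    # Sketch: build the release-date cache once, then point the template
    # featurizer's release-date file at it.
    from openfold.features import templates

    templates.generate_release_dates_cache(
        mmcif_dir="/data/pdb_mmcif/mmcif_files",  # placeholder path
        out_path="release_dates.txt",             # 'txt' suffix expected by _parse_release_dates
    )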
openfold/model/model.py
...
@@ -19,7 +19,6 @@ import torch.nn as nn
from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
...
@@ -115,21 +114,16 @@ class AlphaFold(nn.Module):
                batch,
            )

            single_template_embeds = {}
            if (self.config.template.embed_angles):
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
...
@@ -145,11 +139,11 @@ class AlphaFold(nn.Module):
                _mask_trans=self.config._mask_trans
            )

            single_template_embeds.update({
                "pair": t,
            })

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
...
@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
        )
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if (self.config.template.embed_angles):
            ret["template_angle_embedding"] = template_embeds["angle"]
        ret.update({
            "template_pair_embedding": t,
        })

        return ret

    def iteration(self, feats, m_1_prev, z_prev, x_prev):
        # Primary output dictionary
...
@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
        )

        # Inject information from previous recycling iterations
        if (self.config.num_recycle > 0):
            # Initialize the recycling embeddings, if needs be
            if (None in [m_1_prev, z_prev, x_prev]):
                # [*, N, C_m]
...
@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
        # Embed the templates + merge with MSA/pair embeddings
        if (self.config.template.enabled):
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
...
@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
            )

            # [*, S, N]
            torsion_angles_mask = feats["template_torsion_angles_mask"]
            msa_mask = torch.cat(
                [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
            )
...
@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
                "template_aatype" ([*, N_templ, N_res])
                    Tensor of template residue indices (indices greater
                    than 19 are clamped to 20 (Unknown))
                "template_all_atom_positions"
                    ([*, N_templ, N_res, 37, 3])
                    Template atom coordinates in atom37 format
                "template_all_atom_mask" ([*, N_templ, N_res, 37])
                    Template atom coordinate mask
...
@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
        self._disable_activation_checkpointing()

        # Main recycling loop
        for cycle_no in range(self.config.num_recycle + 1):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = (cycle_no == self.config.num_recycle)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                # Sidestep AMP bug discussed in pytorch issue #65766
                if (is_final_iter):
...
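The recycling configuration switches from a total-iteration count (the old no_cycles) to counting extra recycles (num_recycle), so the model now runs num_recycle + 1 forward passes and initializes recycling embeddings only when num_recycle > 0, with gradients enabled only on the final pass during training. A trivial standalone sketch of that control flow (the numbers are placeholders):

    num_recycle = 3  # config value: number of *extra* recycling passes

    for cycle_no in range(num_recycle + 1):        # 4 forward passes in total
        is_final_iter = cycle_no == num_recycle    # grads enabled only here during training
        print(cycle_no, is_final_iter)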
openfold/utils/exponential_moving_average.py
...
@@ -29,14 +29,15 @@ class ExponentialMovingAverage:
        self.decay = decay

    def _update_state_dict_(self, update, state_dict):
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if (not isinstance(v, torch.Tensor)):
                    self._update_state_dict_(v, stored)
                else:
                    diff = stored - v
                    diff *= (1 - self.decay)
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
...
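The in-place update implements the usual EMA rule, stored ← stored − (1 − decay) · (stored − value), which is the same as stored ← decay · stored + (1 − decay) · value; the change here just wraps it in torch.no_grad(). A standalone numeric check (values are illustrative):

    import torch

    decay = 0.999
    stored = torch.tensor([1.0])
    value = torch.tensor([0.0])

    diff = stored - value
    diff *= (1 - decay)
    stored -= diff
    print(stored)                            # tensor([0.9990])
    print(decay * 1.0 + (1 - decay) * 0.0)   # 0.999, the same quantity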
openfold/utils/feats.py
...
@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
...
@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return
pseudo_beta
return
pseudo_beta
def
get_chi_atom_indices
():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices
=
[]
for
residue_name
in
rc
.
restypes
:
residue_name
=
rc
.
restype_1to3
[
residue_name
]
residue_chi_angles
=
rc
.
chi_angles_atoms
[
residue_name
]
atom_indices
=
[]
for
chi_angle
in
residue_chi_angles
:
atom_indices
.
append
(
[
rc
.
atom_order
[
atom
]
for
atom
in
chi_angle
])
for
_
in
range
(
4
-
len
(
atom_indices
)):
atom_indices
.
append
([
0
,
0
,
0
,
0
])
# For chi angles not defined on the AA.
chi_atom_indices
.
append
(
atom_indices
)
chi_atom_indices
.
append
([[
0
,
0
,
0
,
0
]]
*
4
)
# For UNKNOWN residue.
return
chi_atom_indices
def
atom14_to_atom37
(
atom14
,
batch
):
def
atom14_to_atom37
(
atom14
,
batch
):
atom37_data
=
batched_gather
(
atom37_data
=
batched_gather
(
atom14
,
atom14
,
...
@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
...
@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
return
atom37_data
return
atom37_data
def
atom37_to_torsion_angles
(
def
build_template_angle_feat
(
template_feats
):
aatype
:
torch
.
Tensor
,
template_aatype
=
template_feats
[
"template_aatype"
]
all_atom_positions
:
torch
.
Tensor
,
torsion_angles_sin_cos
=
template_feats
[
"template_torsion_angles_sin_cos"
]
all_atom_mask
:
torch
.
Tensor
,
eps
:
float
=
1e-8
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype
=
torch
.
clamp
(
aatype
,
max
=
20
)
pad
=
all_atom_positions
.
new_zeros
(
[
*
all_atom_positions
.
shape
[:
-
3
],
1
,
37
,
3
]
)
prev_all_atom_positions
=
torch
.
cat
(
[
pad
,
all_atom_positions
[...,
:
-
1
,
:,
:]],
dim
=-
3
)
pad
=
all_atom_mask
.
new_zeros
([
*
all_atom_mask
.
shape
[:
-
2
],
1
,
37
])
prev_all_atom_mask
=
torch
.
cat
([
pad
,
all_atom_mask
[...,
:
-
1
,
:]],
dim
=-
2
)
pre_omega_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
1
:
3
,
:],
all_atom_positions
[...,
:
2
,
:]
],
dim
=-
2
)
phi_atom_pos
=
torch
.
cat
(
[
prev_all_atom_positions
[...,
2
:
3
,
:],
all_atom_positions
[...,
:
3
,
:]
],
dim
=-
2
)
psi_atom_pos
=
torch
.
cat
(
[
all_atom_positions
[...,
:
3
,
:],
all_atom_positions
[...,
4
:
5
,
:]
],
dim
=-
2
)
pre_omega_mask
=
(
torch
.
prod
(
prev_all_atom_mask
[...,
1
:
3
],
dim
=-
1
)
*
torch
.
prod
(
all_atom_mask
[...,
:
2
],
dim
=-
1
)
)
phi_mask
=
(
prev_all_atom_mask
[...,
2
]
*
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
)
psi_mask
=
(
torch
.
prod
(
all_atom_mask
[...,
:
3
],
dim
=-
1
,
dtype
=
all_atom_mask
.
dtype
)
*
all_atom_mask
[...,
4
]
)
chi_atom_indices
=
torch
.
as_tensor
(
get_chi_atom_indices
(),
device
=
aatype
.
device
)
atom_indices
=
chi_atom_indices
[...,
aatype
,
:,
:]
chis_atom_pos
=
batched_gather
(
all_atom_positions
,
atom_indices
,
-
2
,
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angles_mask
=
list
(
rc
.
chi_angles_mask
)
chi_angles_mask
.
append
([
0.
,
0.
,
0.
,
0.
])
chi_angles_mask
=
all_atom_mask
.
new_tensor
(
chi_angles_mask
)
chis_mask
=
chi_angles_mask
[
aatype
,
:]
chi_angle_atoms_mask
=
batched_gather
(
all_atom_mask
,
atom_indices
,
dim
=-
1
,
no_batch_dims
=
len
(
atom_indices
.
shape
[:
-
2
])
)
chi_angle_atoms_mask
=
torch
.
prod
(
chi_angle_atoms_mask
,
dim
=-
1
,
dtype
=
chi_angle_atoms_mask
.
dtype
)
chis_mask
=
chis_mask
*
chi_angle_atoms_mask
torsions_atom_pos
=
torch
.
cat
(
[
pre_omega_atom_pos
[...,
None
,
:,
:],
phi_atom_pos
[...,
None
,
:,
:],
psi_atom_pos
[...,
None
,
:,
:],
chis_atom_pos
,
],
dim
=-
3
)
torsion_angles_mask
=
torch
.
cat
(
[
pre_omega_mask
[...,
None
],
phi_mask
[...,
None
],
psi_mask
[...,
None
],
chis_mask
,
],
dim
=-
1
)
torsion_frames
=
T
.
from_3_points
(
torsions_atom_pos
[...,
1
,
:],
torsions_atom_pos
[...,
2
,
:],
torsions_atom_pos
[...,
0
,
:],
eps
=
eps
,
)
fourth_atom_rel_pos
=
torsion_frames
.
invert
().
apply
(
torsions_atom_pos
[...,
3
,
:]
)
torsion_angles_sin_cos
=
torch
.
stack
(
[
fourth_atom_rel_pos
[...,
2
],
fourth_atom_rel_pos
[...,
1
]],
dim
=-
1
)
denom
=
torch
.
sqrt
(
torch
.
sum
(
torch
.
square
(
torsion_angles_sin_cos
),
dim
=-
1
,
dtype
=
torsion_angles_sin_cos
.
dtype
,
keepdims
=
True
)
+
eps
)
torsion_angles_sin_cos
=
torsion_angles_sin_cos
/
denom
torsion_angles_sin_cos
=
torsion_angles_sin_cos
*
all_atom_mask
.
new_tensor
(
[
1.
,
1.
,
-
1.
,
1.
,
1.
,
1.
,
1.
],
)[((
None
,)
*
len
(
torsion_angles_sin_cos
.
shape
[:
-
2
]))
+
(
slice
(
None
),
None
)]
chi_is_ambiguous
=
torsion_angles_sin_cos
.
new_tensor
(
rc
.
chi_pi_periodic
,
)[
aatype
,
...]
mirror_torsion_angles
=
torch
.
cat
(
[
all_atom_mask
.
new_ones
(
*
aatype
.
shape
,
3
),
1.
-
2.
*
chi_is_ambiguous
],
dim
=-
1
)
alt_torsion_angles_sin_cos
=
(
alt_torsion_angles_sin_cos
=
(
torsion_angles_sin_cos
*
mirror_torsion_angles
[...,
None
]
template_feats
[
"template_alt_torsion_angles_sin_cos"
]
)
return
{
"torsion_angles_sin_cos"
:
torsion_angles_sin_cos
,
"alt_torsion_angles_sin_cos"
:
alt_torsion_angles_sin_cos
,
"torsion_angles_mask"
:
torsion_angles_mask
,
}
def atom37_to_frames(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    eps: float,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    batch_dims = len(aatype.shape[:-1])

    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if (rc.chi_angles_mask[restype][chi_idx]):
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = (
        all_atom_mask.new_tensor(rc.chi_angles_mask)
    )

    lookuptable = rc.atom_order.copy()
    lookuptable[''] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = T.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1])
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    gt_frames = gt_frames.compose(T(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )
    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))

    gt_frames_tensor = gt_frames.to_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_4x4()

    return {
        'rigidgroups_gt_frames': gt_frames_tensor,
        'rigidgroups_gt_exists': gt_exists,
        'rigidgroups_group_exists': group_exists,
        'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,
        'rigidgroups_alt_gt_frames': alt_gt_frames_tensor,
    }
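Likewise, a hedged usage sketch for the frame construction above (dummy tensors; the keyword-style call matches the signature shown here, while the data_transforms version introduced by this commit takes the whole batch dict, as in tests/test_feats.py below):

    import torch

    n_res = 16
    frames = atom37_to_frames(
        aatype=torch.randint(0, 20, (n_res,)),
        all_atom_positions=torch.randn(n_res, 37, 3),
        all_atom_mask=torch.ones(n_res, 37),
        eps=1e-8,
    )
    # rigidgroups_gt_frames:     [N, 8, 4, 4]  backbone, psi and up-to-4 chi frames
    # rigidgroups_gt_exists:     [N, 8]        groups resolved in the ground truth
    # rigidgroups_alt_gt_frames: [N, 8, 4, 4]  frames under the ambiguous-atom renaming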
def build_template_angle_feat(angle_feats, template_aatype):
    torsion_angles_sin_cos = angle_feats["torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = angle_feats["alt_torsion_angles_sin_cos"]
    torsion_angles_mask = angle_feats["torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
...

(The other side of this hunk reads these angle features from the template feature dict instead, e.g. template_feats["template_alt_torsion_angles_sin_cos"] and template_feats["template_torsion_angles_mask"].)

...
@@ -465,7 +132,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
        eps + torch.sum(affine_vec ** 2, dim=-1)
    )

-    t_aa_masks = batch["template_all_atom_masks"]
+    t_aa_masks = batch["template_all_atom_mask"]

    template_mask = (
        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    )
...
@@ -534,53 +201,6 @@ def build_msa_feat(batch):
    return batch


def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Compute features required by compute_renamed_ground_truth (Alg. 26)

    Args:
        batch:
            str/tensor dictionary containing:
                * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
                * atom14_gt_exists: [*, N, 14] atom mask
                * aatype: [*, N] residue indices
    Returns:
        str/tensor dictionary containing:
            * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
            * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
    """
    ambiguous_atoms = (
        batch["atom14_gt_positions"].new_tensor(
            rc.restype_atom14_ambiguous_atoms
        )
    )
    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]

    # Swap pairs of ambiguous positions
    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
    swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
    swap_mat = swap_mat[batch["aatype"], ...]

    atom14_alt_gt_positions = (
        torch.sum(
            batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
            dim=-3
        )
    )
    atom14_alt_gt_exists = (
        torch.sum(batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2)
    )

    return {
        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
        "atom14_alt_gt_positions": atom14_alt_gt_positions,
        "atom14_alt_gt_exists": atom14_alt_gt_exists,
    }
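The one-hot swap matrix used above is easiest to see on a toy example. The following is illustrative only (a fake 4-atom residue, not OpenFold data): indexing an identity matrix with swap_idx yields a permutation matrix that renames the ambiguous atom pair.

    import numpy as np

    swap_idx = np.array([0, 2, 1, 3])                 # atoms 1 and 2 are interchangeable
    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]   # one-hot rows = permutation matrix

    positions = np.array([
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
    ])
    alt_positions = swap_mat @ positions              # rows 1 and 2 swapped, i.e. atoms renamed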
def torsion_angles_to_frames(
    t: T,
    alpha: torch.Tensor,
...
openfold/utils/loss.py View file @ d48c052c

...
@@ -18,6 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
+from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple

from openfold.np import residue_constants
...
...
@@ -117,7 +118,9 @@ def compute_fape(
    return normed_error


-# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
+# DISCREPANCY: From the way this function is written, it's possible that
+# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
+# We defer to the text, which seems to imply the latter.
def backbone_loss(
    backbone_affine_tensor: torch.Tensor,
    backbone_affine_mask: torch.Tensor,
...
@@ -130,7 +133,7 @@ def backbone_loss(
) -> torch.Tensor:
    pred_aff = T.from_tensor(traj)
    gt_aff = T.from_tensor(backbone_affine_tensor)

    fape_loss = compute_fape(
        pred_aff,
        gt_aff[..., None, :],
...
@@ -142,7 +145,6 @@ def backbone_loss(
        length_scale=loss_unit_distance,
        eps=eps,
    )

    if (use_clamped_fape is not None):
        unclamped_fape_loss = compute_fape(
            pred_aff,
...
@@ -157,12 +159,12 @@ def backbone_loss(
        )

        fape_loss = (
            fape_loss * use_clamped_fape +
            unclamped_fape_loss * (1 - use_clamped_fape)
        )

    # Take the mean over the layer dimension
-    fape_loss = torch.mean(fape_loss, dim=0)
+    fape_loss = torch.mean(fape_loss, dim=-1)

    return fape_loss
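The Bernoulli import added above suggests how use_clamped_fape is meant to be produced during training: a per-example draw decides whether the clamped or unclamped FAPE contributes. A minimal sketch of that idea (the 0.9 probability and the tensor shapes are assumptions, not taken from this diff):

    import torch
    from torch.distributions.bernoulli import Bernoulli

    batch_size = 4
    use_clamped_fape = Bernoulli(probs=0.9).sample((batch_size,))  # 1 = clamp, 0 = don't

    clamped_fape = torch.rand(batch_size)     # stand-ins for the two FAPE variants
    unclamped_fape = torch.rand(batch_size)

    # Same blend as in backbone_loss above
    fape_loss = (
        clamped_fape * use_clamped_fape
        + unclamped_fape * (1 - use_clamped_fape)
    )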
...
@@ -1453,7 +1455,7 @@ class AlphaFoldLoss(nn.Module):
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def forward(self, out, batch):
        if ("violation" not in out.keys() and self.config.violation.weight):
            out["violation"] = find_structural_violations(
                batch,
...
@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
                **self.config.violation,
            )

-        if ("atom14_atom_is_ambiguous" not in batch.keys()):
-            batch.update(feats.build_ambiguity_feats(batch))

        if ("renamed_atom14_gt_positions" not in out.keys()):
            batch.update(compute_renamed_ground_truth(
                batch,
                out["sm"]["positions"][-1],
            ))

-        if ("backbone_affine_tensor" not in batch.keys()):
-            batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
-            # TODO: Verify that this is correct
-            batch["backbone_affine_tensor"] = (
-                batch["rigidgroups_gt_frames"][..., 0, :, :]
-            )
-            batch["backbone_affine_mask"] = (
-                batch["rigidgroups_gt_exists"][..., 0]
-            )

-        if ("chi_angles_sin_cos" not in batch.keys()):
-            with torch.no_grad():
-                batch.update(feats.atom37_to_torsion_angles(
-                    aatype=batch["aatype"],
-                    all_atom_positions=batch["all_atom_positions"].double(),
-                    all_atom_mask=batch["all_atom_mask"].double(),
-                    eps=self.config.eps,
-                ))
-            # TODO: Verify that this is correct
-            batch["chi_angles_sin_cos"] = (
-                batch["torsion_angles_sin_cos"][..., 3:, :]
-            ).to(batch["all_atom_mask"].dtype)
-            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(
-                batch["all_atom_mask"].dtype
-            )

        loss_fns = {
            "distogram":
                lambda: distogram_loss(
...
run_pretrained_openfold.py View file @ d48c052c

...
@@ -15,17 +15,17 @@
import argparse
from datetime import date
-import pickle
+import logging
import os

# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"

+import pickle
import random
import sys

-from openfold.features import templates, feature_pipeline
-from openfold.features.np import data_pipeline
+from openfold.features import templates, feature_pipeline, data_pipeline
import time
...
@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
    tensor_tree_map,
)

-MAX_TEMPLATE_HITS = 20
+from scripts.utils import add_data_args
def main(args):
    config = model_config(args.model_name)
    model = AlphaFold(config.model)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
-    model = model.to(args.device)
+    model = model.to(args.model_device)

    # FEATURE COLLECTION AND PROCESSING
-    use_small_bfd = args.preset == "reduced_dbs"
    num_ensemble = 1
    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
-        max_hits=MAX_TEMPLATE_HITS,
+        max_hits=args.max_template_hits,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )

+    use_small_bfd = (args.bfd_database_path is None)
    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
...
@@ -76,6 +77,7 @@ def main(args):
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=use_small_bfd,
+        no_cpus=args.cpus,
    )
    data_processor = data_pipeline.DataPipeline(
...
@@ -87,7 +89,7 @@ def main(args):
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)

-    config.data.eval.num_ensemble = num_ensemble
+    config.data.predict.num_ensemble = num_ensemble
    feature_processor = feature_pipeline.FeaturePipeline(config)

    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
...
@@ -95,7 +97,7 @@ def main(args):
    if not os.path.exists(alignment_dir):
        os.makedirs(alignment_dir)

-    print("Generating features...")
+    logging.info("Generating features...")
    alignment_runner.run(
        args.fasta_path, alignment_dir
    )
...
@@ -105,42 +107,20 @@ def main(args):
    )

    processed_feature_dict = feature_processor.process_features(
-        feature_dict, random_seed
+        feature_dict, mode='predict',
    )

-    for k, v in processed_feature_dict.items():
-        print(k)
-        print(v.shape)

-    print("Executing model...")
+    logging.info("Executing model...")
    batch = processed_feature_dict
    with torch.no_grad():
        batch = {
-            k: torch.as_tensor(v, device=args.device)
+            k: torch.as_tensor(v, device=args.model_device)
            for k, v in batch.items()
        }

-        longs = [
-            "aatype",
-            "template_aatype",
-            "extra_msa",
-            "residx_atom37_to_atom14",
-            "residx_atom14_to_atom37",
-            "true_msa",
-            "residue_index",
-        ]
-        for l in longs:
-            batch[l] = batch[l].long()

-        # Move the recycling dimension to the end
-        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
-        batch = tensor_tree_map(move_dim, batch)
-        make_contig = lambda t: t.contiguous()
-        batch = tensor_tree_map(make_contig, batch)

        t = time.time()
        out = model(batch)
-        print(f"Inference time: {time.time() - t}")
+        logging.info(f"Inference time: {time.time() - t}")

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
...
@@ -158,9 +138,7 @@ def main(args):
        result=out,
        b_factors=plddt_b_factors
    )

-    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    amber_relaxer = relax.AmberRelaxation(
        **config.relax
    )
...
@@ -168,7 +146,7 @@ def main(args):
    # Relax the prediction.
    t = time.time()
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    print(f"Relaxation time: {time.time() - t}")
+    logging.info(f"Relaxation time: {time.time() - t}")

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
...
@@ -183,53 +161,14 @@ if __name__ == "__main__":
    parser.add_argument(
        "fasta_path", type=str,
    )
-    parser.add_argument(
-        'uniref90_database_path', type=str,
-    )
-    parser.add_argument(
-        'mgnify_database_path', type=str,
-    )
-    parser.add_argument(
-        'pdb70_database_path', type=str,
-    )
-    parser.add_argument(
-        'template_mmcif_dir', type=str,
-    )
-    parser.add_argument(
-        '--uniclust30_database_path', type=str,
-    )
-    parser.add_argument(
-        '--bfd_database_path', type=str, default=None,
-    )
-    parser.add_argument('--small_bfd_database_path', type=str, default=None)
-    parser.add_argument('--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer')
-    parser.add_argument('--hhblits_binary_path', type=str, default='/usr/bin/hhblits')
-    parser.add_argument('--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch')
-    parser.add_argument('--kalign_binary_path', type=str, default='/usr/bin/kalign')
-    parser.add_argument(
-        '--max_template_date', type=str,
-        default=date.today().strftime("%Y-%m-%d"),
-    )
-    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
+    add_data_args(parser)
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
        required=True
    )
    parser.add_argument(
-        "--device", type=str, default="cpu",
+        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
...
@@ -244,6 +183,10 @@ if __name__ == "__main__":
             automatically according to the model name from
             openfold/resources/params"""
    )
+    parser.add_argument(
+        "--cpus", type=int, default=4,
+        help="""Number of CPUs to use to run alignment tools"""
+    )
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
    )
...
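With add_data_args wired in, a hypothetical invocation of the script looks like `python run_pretrained_openfold.py target.fasta uniref90.fasta mgy_clusters.fa pdb70/pdb70 pdb_mmcif/mmcif_files uniclust30/uniclust30 --model_device cuda:0 --cpus 16 --output_dir ./predictions` (all paths are placeholders; the five database arguments are the positionals added by add_data_args, and --output_dir is required).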
scripts/build_deepspeed_config.py View file @ d48c052c

...
@@ -15,6 +15,7 @@
import argparse
import json

parser = argparse.ArgumentParser(description='''Outputs a DeepSpeed
                                                configuration file to
                                                stdout''')
...
scripts/precompute_alignments.py 0 → 100644 View file @ d48c052c

import argparse
import logging
import os
import tempfile

import openfold.features.mmcif_parsing as mmcif_parsing
from openfold.features import parsers  # needed for parsers.parse_fasta below
from openfold.features.data_pipeline import AlignmentRunner
from scripts.utils import add_data_args


def main(args):
    # Build the alignment tool runner
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus,
    )

    for f in os.listdir(args.input_dir):
        path = os.path.join(args.input_dir, f)
        is_mmcif = f.endswith('.cif')
        is_fasta = f.endswith('.fasta')
        file_id = os.path.splitext(f)[0]
        seqs = {}
        if (is_mmcif):
            with open(path, 'r') as fp:
                mmcif_str = fp.read()
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_str
            )
            if (mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}...')
                if (args.raise_errors):
                    raise list(mmcif.errors.values())[0]
                else:
                    continue
            mmcif = mmcif.mmcif_object
            for k, v in mmcif.chain_to_seqres.items():
                chain_id = '_'.join([file_id, k])
                seqs[chain_id] = v
        elif (is_fasta):
            with open(path, 'r') as fp:
                fasta_str = fp.read()
            input_seqs, _ = parsers.parse_fasta(fasta_str)
            if len(input_seqs) != 1:
                msg = f'More than one input_sequence found in {f}'
                if (args.raise_errors):
                    raise ValueError(msg)
                else:
                    logging.warning(msg)
            input_sequence = input_seqs[0]
            seqs[file_id] = input_sequence
        else:
            continue

        for name, seq in seqs.items():
            alignment_dir = os.path.join(args.output_dir, name)
            if (os.path.isdir(alignment_dir)):
                logging.info(f'{f} has already been processed. Skipping...')
                continue

            os.makedirs(alignment_dir)

            if (not is_fasta):
                fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
                with os.fdopen(fd, 'w') as fp:
                    fp.write(f'>query\n{seq}')

            alignment_runner.run(
                f if is_fasta else fasta_path, alignment_dir
            )

            if (not is_fasta):
                os.remove(fasta_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir", type=str,
        help="Path to directory containing mmCIF and/or FASTA files"
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Directory in which to output alignments"
    )
    add_data_args(parser)
    parser.add_argument(
        "--raise_errors", type=bool, default=False,
        help="Whether to crash on parsing errors"
    )
    parser.add_argument(
        "--cpus", type=int, default=4,
        help="Number of CPUs to use"
    )

    args = parser.parse_args()

    main(args)
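As a usage sketch (paths are illustrative), pointing the script at a directory of .cif and/or .fasta files precomputes one alignment subdirectory per chain: `python scripts/precompute_alignments.py data/mmcifs/ alignments/ uniref90.fasta mgy_clusters.fa pdb70/pdb70 pdb_mmcif/mmcif_files uniclust30/uniclust30 --cpus 16`.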
scripts/utils.py 0 → 100644 View file @ d48c052c

import argparse
from datetime import date


def add_data_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        'uniref90_database_path', type=str,
    )
    parser.add_argument(
        'mgnify_database_path', type=str,
    )
    parser.add_argument(
        'pdb70_database_path', type=str,
    )
    parser.add_argument(
        'template_mmcif_dir', type=str,
    )
    parser.add_argument(
        'uniclust30_database_path', type=str,
    )
    parser.add_argument(
        '--bfd_database_path', type=str, default=None,
    )
    parser.add_argument('--small_bfd_database_path', type=str, default=None)
    parser.add_argument(
        '--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
    )
    parser.add_argument(
        '--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
    )
    parser.add_argument(
        '--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
    )
    parser.add_argument(
        '--kalign_binary_path', type=str, default='/usr/bin/kalign'
    )
    parser.add_argument(
        '--max_template_date', type=str,
        default=date.today().strftime("%Y-%m-%d"),
    )
    parser.add_argument(
        '--max_template_hits', type=int, default=20,
    )
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
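A minimal sketch of how another script can reuse these shared flags (the fasta_path positional and the argument values below are illustrative, not part of this commit):

    import argparse

    from scripts.utils import add_data_args

    parser = argparse.ArgumentParser()
    parser.add_argument("fasta_path", type=str)   # script-specific positional
    add_data_args(parser)                         # shared database/tool arguments

    args = parser.parse_args([
        "query.fasta",            # fasta_path
        "uniref90.fasta",         # uniref90_database_path
        "mgy_clusters.fa",        # mgnify_database_path
        "pdb70/pdb70",            # pdb70_database_path
        "pdb_mmcif/mmcif_files",  # template_mmcif_dir
        "uniclust30/uniclust30",  # uniclust30_database_path
    ])
    print(args.max_template_hits)  # -> 20 (default)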
tests/compare_utils.py View file @ d48c052c

import os

+os.environ["CUDA_VISIBLE_DEVICES"] = "4,"

import importlib
import pkgutil
import sys
...
tests/test_feats.py View file @ d48c052c

...
@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest

+import openfold.features.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
...
@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

-        out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
+        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
...