OpenDAS / OpenFold · Commits

Commit a59ae7c1, authored Oct 17, 2021 by Gustaf Ahdritz

    Refactor data pipeline; add distillation parsing

parent 07e64267
Showing 10 changed files with 532 additions and 352 deletions (+532 / -352)
openfold/__init__.py              +1    -0
openfold/config.py                +3    -0
openfold/data/data_modules.py     +403  -0
openfold/data/data_pipeline.py    +87   -6
openfold/data/data_transforms.py  +6    -8
openfold/data/templates.py        +1    -1
openfold/utils/deepspeed.py       +2    -2
openfold/utils/loss.py            +0    -3
run_pretrained_openfold.py        +1    -1
train_openfold.py                 +28   -331
openfold/__init__.py

from . import model
from . import utils
from . import np

__all__ = ["model", "utils", "np"]
openfold/config.py

...
@@ -183,6 +183,7 @@ config = mlc.ConfigDict(
             "subsample_templates": False,  # We want top templates.
             "masked_msa_replace_fraction": 0.15,
             "max_msa_clusters": 512,
+            "max_template_hits": 4,
             "max_templates": 4,
             "num_ensemble": 1,
             "crop": False,
...
@@ -194,6 +195,7 @@ config = mlc.ConfigDict(
             "subsample_templates": False,  # We want top templates.
             "masked_msa_replace_fraction": 0.15,
             "max_msa_clusters": 512,
+            "max_template_hits": 4,
             "max_templates": 4,
             "num_ensemble": 1,
             "crop": False,
...
@@ -205,6 +207,7 @@ config = mlc.ConfigDict(
             "subsample_templates": True,
             "masked_msa_replace_fraction": 0.15,
             "max_msa_clusters": 512,
+            "max_template_hits": 20,
             "max_templates": 4,
             "num_ensemble": 1,
             "crop": True,
...
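The three hunks give each data stage its own cap on the number of template hits kept (20 for the training stage, 4 for the other two). A minimal sketch of the new key, with the surrounding structure abbreviated from openfold/config.py (the exact nesting here is an assumption):

    import ml_collections as mlc

    # Hedged sketch: only the per-stage max_template_hits values are taken
    # from the diff above; the rest of the config is elided
    config = mlc.ConfigDict({
        "train":   {"max_template_hits": 20, "max_templates": 4},
        "eval":    {"max_template_hits": 4,  "max_templates": 4},
        "predict": {"max_template_hits": 4,  "max_templates": 4},
    })
    # The new data module (added below) forwards this value to
    # templates.TemplateHitFeaturizer(max_hits=...)
    print(config.train.max_template_hits)  # 20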
openfold/data/data_modules.py (new file, 0 → 100644)
from functools import partial
import json
import logging
import os
from typing import Optional, Sequence

import ml_collections as mlc
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        alignment_dir: str,
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        kalign_binary_path: str = '/usr/bin/kalign',
        mapping_path: Optional[str] = None,
        max_template_hits: int = 4,
        template_release_dates_cache_path: Optional[str] = None,
        use_small_bfd: bool = True,
        output_raw: bool = False,
        mode: str = "train",
    ):
        """
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode).
                alignment_dir:
                    A path to a directory containing only data in the format
                    output by an AlignmentRunner
                    (defined in openfold.features.alignment_runner).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing:
                        * bfd_uniclust_hits.a3m/small_bfd_hits.sto
                        * mgnify_hits.a3m
                        * pdb70_hits.hhr
                        * uniref90_hits.a3m
                config:
                    A dataset config object. See openfold.config
                mapping_path:
                    A json file containing a mapping from consecutive numerical
                    ids to sample names (matching the directories in data_dir).
                    Samples not in this mapping are ignored. Can be used to
                    implement the various training-time filters described in
                    the AlphaFold supplement.
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.config = config
        self.output_raw = output_raw
        self.mode = mode

        valid_modes = ["train", "val", "predict"]
        if mode not in valid_modes:
            raise ValueError(f'mode must be one of {valid_modes}')

        if mapping_path is None:
            self.mapping = {
                str(i): os.path.splitext(name)[0]
                for i, name in enumerate(os.listdir(alignment_dir))
            }
        else:
            with open(mapping_path, 'r') as fp:
                self.mapping = json.load(fp)

        if template_release_dates_cache_path is None:
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_caches.py before running OpenFold"
            )

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=None,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
            use_small_bfd=use_small_bfd,
        )

        if not self.output_raw:
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if mmcif_object.mmcif_object is None:
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
        )

        return data
    def __getitem__(self, idx):
        name = self.mapping[str(idx)]
        alignment_dir = os.path.join(self.alignment_dir, name)

        if self.mode == 'train' or self.mode == 'val':
            spl = name.rsplit('_', 1)
            if len(spl) == 2:
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id + '.cif')
            if os.path.exists(path):
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir
                )
            else:
                # Try to search for a distillation PDB file instead
                path = os.path.join(self.data_dir, file_id + '.pdb')
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                )
        else:
            # Fix: the original joined name with itself; per the docstring,
            # FASTA files live in data_dir
            path = os.path.join(self.data_dir, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,  # fix: original passed the undefined name "feats"
                alignment_dir=alignment_dir,
            )

        if self.output_raw:
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode, "unclamped"
        )

        return feats

    def __len__(self):
        return len(self.mapping.keys())
def looped_sequence(sequence):
    while True:
        for x in sequence:
            yield x
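looped_sequence restarts iteration over sequence each time it is exhausted. Worth noting why itertools.cycle would not be a drop-in replacement here: cycle caches the elements of its first pass and replays them, so a wrapped RandomSampler would repeat one permutation forever, whereas this generator re-enters the sampler and draws a fresh permutation on every pass. A small illustrative sketch:

    import itertools
    from torch.utils.data import RandomSampler

    sampler = RandomSampler(range(4))
    # cycle caches the first permutation and replays it verbatim...
    print(list(itertools.islice(itertools.cycle(sampler), 8)))
    # ...while looped_sequence draws a new permutation every four elements
    print(list(itertools.islice(looped_sequence(sampler), 8)))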
class OpenFoldDataset(torch.utils.data.IterableDataset):
    """
        The Dataset is written to accommodate the requirement that proteins are
        sampled from the distillation set with some probability p
        and from the PDB set with probability (1 - p). Proteins are sampled
        from both sets without replacement, and as soon as either set is
        emptied, it is refilled. The Dataset therefore has an arbitrary length.
        Nevertheless, for compatibility with various PyTorch Lightning
        functionalities, it is possible to specify an epoch length. This length
        has no effect on the output of the Dataset.
    """
    def __init__(
        self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
    ):
        self.datasets = datasets
        self.samplers = [
            looped_sequence(RandomSampler(d)) for d in datasets
        ]
        # NOTE: the original assigned self.batch_size = batch_size here, but
        # no such parameter exists (a NameError); the line has been dropped
        self.epoch_len = epoch_len
        self.distr = torch.distributions.categorical.Categorical(
            probs=torch.tensor(probabilities),
        )

    def __iter__(self):
        return self

    def __next__(self):
        dataset_idx = self.distr.sample()
        sampler = self.samplers[dataset_idx]
        element_idx = next(sampler)
        return self.datasets[dataset_idx][element_idx]

    def __len__(self):
        return self.epoch_len
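A minimal usage sketch of the sampling scheme the docstring describes, with a distillation probability of 0.25; pdb_set and distillation_set are hypothetical stand-ins for two OpenFoldSingleDataset instances:

    # Hypothetical values for illustration only
    train_set = OpenFoldDataset(
        datasets=[pdb_set, distillation_set],
        probabilities=[0.75, 0.25],   # (1 - p) and p from the docstring
        epoch_len=len(pdb_set) + len(distillation_set),
    )
    it = iter(train_set)
    samples = [next(it) for _ in range(4)]  # draws mix the two sets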
class OpenFoldBatchCollator:
    def __init__(self, config, generator, stage="train"):
        self.config = config
        batch_modes = config.common.batch_modes
        batch_mode_names, batch_mode_probs = list(zip(*batch_modes))
        self.batch_mode_names = batch_mode_names
        self.batch_mode_probs = batch_mode_probs
        self.generator = generator
        self.stage = stage
        self.batch_mode_probs_tensor = torch.tensor(self.batch_mode_probs)
        self.feature_pipeline = feature_pipeline.FeaturePipeline(self.config)

    def __call__(self, raw_prots):
        # We use torch.multinomial here rather than Categorical because the
        # latter doesn't accept a generator for some reason
        batch_mode_idx = torch.multinomial(
            self.batch_mode_probs_tensor,
            1,
            generator=self.generator,
        ).item()
        batch_mode_name = self.batch_mode_names[batch_mode_idx]

        processed_prots = []
        for prot in raw_prots:
            features = self.feature_pipeline.process_features(
                prot, self.stage, batch_mode_name
            )
            processed_prots.append(features)

        stack_fn = partial(torch.stack, dim=0)
        return dict_multimap(stack_fn, processed_prots)
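The comment in __call__ is accurate: torch.distributions.Categorical exposes no way to pass a torch.Generator to sample(), whereas torch.multinomial does, which is what makes the collator's per-batch mode draws reproducible across worker processes. A standalone sketch:

    import torch

    g = torch.Generator()
    g.manual_seed(0)
    probs = torch.tensor([0.9, 0.1])
    # Same generator state => the same sequence of batch-mode draws everywhere
    idx = torch.multinomial(probs, 1, generator=g).item()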
class OpenFoldDataModule(pl.LightningDataModule):
    def __init__(
        self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )

        if self.train_data_dir is None and self.predict_data_dir is None:
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None
        if self.training_mode and self.train_alignment_dir is None:
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif not self.training_mode and self.predict_alignment_dir is None:
            # fix: the original misspelled the attribute as predict_alingment_dir
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )
        elif val_data_dir is not None and val_alignment_dir is None:
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )
    def setup(self, stage):
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(
            OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=self.template_release_dates_cache_path,
            use_small_bfd=self.config.data_module.use_small_bfd,
        )

        if self.training_mode:
            self.train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                output_raw=True,
                mode="train",
            )

            if self.distillation_data_dir is not None:
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    # fix: the original read self.train.max_template_hits,
                    # skipping the config attribute
                    max_template_hits=self.config.train.max_template_hits,
                    output_raw=True,
                    mode="train",
                )

                d_prob = self.config.train.distillation_prob
                self.train_dataset = OpenFoldDataset(
                    datasets=[self.train_dataset, distillation_dataset],
                    probabilities=[1 - d_prob, d_prob],
                    epoch_len=(
                        # fix: the original called a nonexistent .len() method
                        len(self.train_dataset) + len(distillation_dataset)
                    ),
                )

            if self.val_data_dir is not None:
                self.val_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="val",  # fix: "eval" is not one of the valid dataset modes
                )
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

        self.batch_collation_seed = torch.Generator().seed()
    def _gen_batch_collator(self, stage):
        """ We want each process to use the same batch collation seed """
        generator = torch.Generator()
        generator = generator.manual_seed(self.batch_collation_seed)
        collate_fn = OpenFoldBatchCollator(self.config, generator, stage)
        return collate_fn
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=self._gen_batch_collator("train"),
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=self._gen_batch_collator("eval")
        )

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(
            self.predict_dataset,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=self._gen_batch_collator("eval")
        )
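A minimal sketch of wiring the new module into a Lightning run; the paths, the config access pattern, and model_module are hypothetical, while the argument names match the listing above:

    import pytorch_lightning as pl

    data_module = OpenFoldDataModule(
        config=config.data,                         # assumed: the data sub-config
        template_mmcif_dir="data/pdb_mmcif",        # hypothetical paths
        max_template_date="2021-10-10",
        train_data_dir="data/mmcif",
        train_alignment_dir="data/alignments",
        distillation_data_dir="data/distillation",  # optional: enables mixed sampling
        distillation_alignment_dir="data/distillation_alignments",
    )
    trainer = pl.Trainer(gpus=1)
    trainer.fit(model_module, datamodule=data_module)  # model_module: hypothetical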
openfold/data/data_pipeline.py
...
@@ -22,7 +22,7 @@ import numpy as np
 from openfold.data import templates, parsers, mmcif_parsing
 from openfold.data.tools import jackhmmer, hhblits, hhsearch
 from openfold.data.tools.utils import to_date
-from openfold.np import residue_constants
+from openfold.np import residue_constants, protein

 FeatureDict = Mapping[str, np.ndarray]
...
@@ -81,9 +81,43 @@ def make_mmcif_features(
         [mmcif_object.header["release_date"].encode("utf-8")],
         dtype=np.object_
     )
+    mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)

     return mmcif_feats
+def make_pdb_features(
+    protein_object: protein.Protein,
+    description: str,
+    confidence_threshold: float = 0.5,
+) -> FeatureDict:
+    pdb_feats = {}
+    pdb_feats.update(
+        make_sequence_features(
+            sequence=protein_object.aatype,
+            description=description,
+            num_res=len(protein_object.aatype),
+        )
+    )
+
+    all_atom_positions = protein_object.atom_positions
+    all_atom_mask = protein_object.atom_mask
+
+    # Fix: the original read protein.b_factors, i.e. the module rather than
+    # the parsed Protein object
+    high_confidence = protein_object.b_factors > confidence_threshold
+    high_confidence = np.any(high_confidence, axis=-1)
+    for i, confident in enumerate(high_confidence):
+        if not confident:
+            all_atom_mask[i] = 0
+
+    pdb_feats["all_atom_positions"] = all_atom_positions
+    pdb_feats["all_atom_mask"] = all_atom_mask
+
+    pdb_feats["is_distillation"] = np.array(1.).astype(np.float32)
+
+    return pdb_feats
 def make_msa_features(
     msas: Sequence[Sequence[str]],
     deletion_matrices: Sequence[parsers.DeletionMatrix],
...
@@ -311,7 +345,11 @@ class DataPipeline:
             alignments["mgnify_deletion_matrix"],
         ),
     )

-    return {**sequence_features, **msa_features, **templates_result.data}
+    return {**sequence_features, **msa_features, **templates_result.features}

 def process_mmcif(
     self,
...
@@ -357,4 +395,47 @@ class DataPipeline:
         ),
     )

-    return {**mmcif_feats, **templates_result.data, **msa_features}
+    return {**mmcif_feats, **templates_result.features, **msa_features}
+    def process_pdb(
+        self,
+        pdb_path: str,
+        alignment_dir: str,
+    ) -> FeatureDict:
+        """
+            Assembles features for a protein in a PDB file.
+        """
+        with open(pdb_path, 'r') as f:
+            pdb_str = f.read()  # fix: the original assigned pdb_path here
+
+        protein_object = protein.from_pdb_string(pdb_str)
+        # Fix: pass a description, which make_pdb_features requires; the file
+        # stem is a reasonable stand-in
+        description = os.path.splitext(os.path.basename(pdb_path))[0]
+        pdb_feats = make_pdb_features(protein_object, description)
+
+        alignments = self._parse_alignment_output(alignment_dir)
+
+        # Fix: the original carried mmCIF-specific lines over from
+        # process_mmcif (a make_mmcif_features call, mmcif.chain_to_seqres,
+        # and the mmCIF release date), all referencing an undefined mmcif
+        # variable. A PDB distillation target has none of these, so the
+        # sequence is recovered from the parsed Protein object instead and
+        # no release date is supplied.
+        input_sequence = ''.join(
+            residue_constants.restypes_with_x[aa]
+            for aa in protein_object.aatype
+        )
+
+        templates_result = self.template_featurizer.get_templates(
+            query_sequence=input_sequence,
+            query_pdb_code=None,
+            query_release_date=None,
+            hits=alignments["hhsearch_hits"],
+        )
+
+        msa_features = make_msa_features(
+            msas=(
+                alignments["uniref90_msa"],
+                alignments["bfd_msa"],
+                alignments["mgnify_msa"],
+            ),
+            deletion_matrices=(
+                alignments["uniref90_deletion_matrix"],
+                alignments["bfd_deletion_matrix"],
+                alignments["mgnify_deletion_matrix"],
+            ),
+        )
+
+        return {**pdb_feats, **templates_result.features, **msa_features}
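The b-factor thresholding in make_pdb_features is what distinguishes distillation targets: for predicted structures the b-factor column stores per-residue confidence, and any residue whose atoms all fall below confidence_threshold has its atom mask zeroed. A small numpy sketch of the same masking rule (the shapes are assumptions based on the code above):

    import numpy as np

    # Assumed shapes: [num_res, num_atoms] b-factors
    b_factors = np.array([[0.9, 0.8], [0.2, 0.3], [0.1, 0.7]])
    high_confidence = np.any(b_factors > 0.5, axis=-1)  # per residue: [T, F, T]
    atom_mask = np.ones_like(b_factors)
    atom_mask[~high_confidence] = 0                     # residue 1 is masked out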
openfold/data/data_transforms.py
...
@@ -21,7 +21,7 @@ import numpy as np
 import torch

 from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
-from openfold.tools import residue_constants as rc
+from openfold.np import residue_constants as rc
 from openfold.utils.affine_utils import T
 from openfold.utils.tensor_utils import (
     tree_map,
...
@@ -1104,7 +1104,7 @@ def random_crop_to_size(
     else:
         num_templates = protein["aatype"].new_zeros((1,))
-    num_res_crop_size = min(seq_length, crop_size)
+    num_res_crop_size = min(seq_length.item(), crop_size)

     # We want each ensemble to be cropped the same way
     g = torch.Generator(device=protein["seq_length"].device)
...
@@ -1112,18 +1112,16 @@ def random_crop_to_size(
     g.manual_seed(seed)

     def _randint(lower, upper):
-        return int(
-            torch.randint(
+        return torch.randint(
                 lower,
-                upper,
+                upper + 1,
                 (1,),
                 device=protein["seq_length"].device,
                 generator=g,
-            )[0]
-        )
+        )[0].item()

     if subsample_templates:
-        templates_crop_start = _randint(0, num_templates + 1)
+        templates_crop_start = _randint(0, num_templates)
         templates_select_indices = torch.randperm(
             num_templates, device=protein["seq_length"].device, generator=g
         )
...
openfold/data/templates.py
...
@@ -130,7 +130,7 @@ def _is_after_cutoff(
     else:
         # Since this is just a quick prefilter to reduce the number of mmCIF files
         # we need to parse, we don't have to worry about returning True here.
-        logging.warning(
+        logging.info(
             "Template structure not in release dates dict: %s", pdb_id
         )

     return False
...
openfold/utils/deepspeed.py
...
@@ -72,8 +72,8 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
-        args = checkpoint(chunker(s, e), *args)
-        # args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
+        # args = checkpoint(chunker(s, e), *args)
+        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)
     return args
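checkpoint_blocks trades compute for memory by checkpointing runs of blocks_per_ckpt blocks at a time; this hunk simply switches the active implementation to DeepSpeed's checkpointing API. A self-contained sketch of the same chunked-checkpoint pattern using the stock PyTorch API; the chunker here is a simplified stand-in for the one in this file:

    import torch
    from torch.utils.checkpoint import checkpoint

    def checkpoint_chunks(blocks, x, blocks_per_ckpt=2):
        # Run blocks[s:e] as one checkpointed unit: activations inside the
        # chunk are freed after the forward pass, recomputed during backward
        def chunker(s, e):
            def run(t):
                for b in blocks[s:e]:
                    t = b(t)
                return t
            return run

        for s in range(0, len(blocks), blocks_per_ckpt):
            x = checkpoint(chunker(s, s + blocks_per_ckpt), x)
        return x

    blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
    y = checkpoint_chunks(blocks, torch.randn(2, 8, requires_grad=True))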
openfold/utils/loss.py
...
@@ -1464,10 +1464,7 @@ class AlphaFoldLoss(nn.Module):
         for k, loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if weight:
-                # print(k)
                 loss = loss_fn()
-                # print(weight * loss)
                 cum_loss = cum_loss + weight * loss
-                # print(cum_loss)
         return cum_loss
run_pretrained_openfold.py
...
@@ -87,7 +87,7 @@ def main(args):
     if random_seed is None:
         random_seed = random.randrange(sys.maxsize)

     config.data.predict.num_ensemble = num_ensemble
-    feature_processor = feature_pipeline.FeaturePipeline(config)
+    feature_processor = feature_pipeline.FeaturePipeline(config.data)
     if not os.path.exists(output_dir_base):
         os.makedirs(output_dir_base)
     alignment_dir = os.path.join(output_dir_base, "alignments")
...
train_openfold.py
 import argparse
-from functools import partial
-import json
 import logging
 import os

 os.environ["CUDA_VISIBLE_DEVICES"] = "4"

+import random
 import time
-from typing import Optional

-import ml_collections as mlc
+import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
 import torch
-from torch.utils.data import RandomSampler
-
-torch.manual_seed(42)

 from openfold.config import model_config
+from openfold.data.data_modules import (
+    OpenFoldDataModule,
+)
 from openfold.model.model import AlphaFold
-from openfold.features import (
-    data_pipeline,
-    feature_pipeline,
-    mmcif_parsing,
-)
-from openfold.features import templates
-from openfold.features.np.utils import to_date
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss
-from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
+from openfold.utils.tensor_utils import tensor_tree_map
-class OpenFoldDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        data_dir: str,
-        alignment_dir: str,
-        template_mmcif_dir: str,
-        max_template_date: str,
-        config: mlc.ConfigDict,
-        kalign_binary_path: str = '/usr/bin/kalign',
-        mapping_path: Optional[str] = None,
-        mmcif_cache_dir: str = 'tmp/',
-        use_small_bfd: bool = True,
-        seed: int = 42,
-        mode: str = "train",
-    ):
-        """
-            Args: (identical to the docstring of OpenFoldSingleDataset.__init__
-            in openfold/data/data_modules.py above)
-        """
-        super(OpenFoldDataset, self).__init__()
-        self.data_dir = data_dir
-        self.alignment_dir = alignment_dir
-        self.config = config
-        self.seed = seed
-        self.mode = mode
-
-        valid_modes = ["train", "val", "predict"]
-        if mode not in valid_modes:
-            raise ValueError(f'mode must be one of {valid_modes}')
-
-        if mapping_path is None:
-            self.mapping = {
-                str(i): os.path.splitext(name)[0]
-                for i, name in enumerate(os.listdir(alignment_dir))
-            }
-        else:
-            with open(mapping_path, 'r') as fp:
-                self.mapping = json.load(fp)
-
-        template_release_dates_path = os.path.join(
-            mmcif_cache_dir, "template_release_dates.json"
-        )
-        if not os.path.exists(template_release_dates_path):
-            logging.warning(
-                "Template release dates cache does not exist. Remember to run "
-                "scripts/generate_mmcif_caches.py before running OpenFold"
-            )
-            template_release_dates_path = None
-
-        template_featurizer = templates.TemplateHitFeaturizer(
-            mmcif_dir=template_mmcif_dir,
-            max_template_date=max_template_date,
-            max_hits=(20 if mode == 'train' else 4),
-            kalign_binary_path=kalign_binary_path,
-            release_dates_path=template_release_dates_path,
-            obsolete_pdbs_path=None,
-        )
-
-        self.data_pipeline = data_pipeline.DataPipeline(
-            template_featurizer=template_featurizer,
-            use_small_bfd=use_small_bfd,
-        )
-
-        self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
-
-    def __getitem__(self, idx):
-        no_batch_modes = len(self.config.common.batch_modes)
-        batch_mode_idx = idx % no_batch_modes
-        batch_mode_str = self.config.common.batch_modes[batch_mode_idx][0]
-        idx = int(idx / no_batch_modes)
-
-        name = self.mapping[str(idx)]
-        if self.mode == 'train' or self.mode == 'val':
-            spl = name.rsplit('_', 1)
-            if len(spl) == 2:
-                file_id, chain_id = spl
-            else:
-                file_id, = spl
-                chain_id = None
-
-            path = os.path.join(self.data_dir, file_id + '.cif')
-            with open(path, 'r') as f:
-                mmcif_string = f.read()
-
-            mmcif_object = mmcif_parsing.parse(
-                file_id=file_id, mmcif_string=mmcif_string
-            )
-
-            # Crash if an error is encountered. Any parsing errors should have
-            # been dealt with at the alignment stage.
-            if mmcif_object.mmcif_object is None:
-                raise list(mmcif_object.errors.values())[0]
-
-            mmcif_object = mmcif_object.mmcif_object
-
-            alignment_dir = os.path.join(self.alignment_dir, name)
-            data = self.data_pipeline.process_mmcif(
-                mmcif=mmcif_object,
-                alignment_dir=alignment_dir,
-                chain_id=chain_id,
-            )
-        else:
-            path = os.path.join(name, name + '.fasta')
-            data = self.data_pipeline.process_fasta(
-                fasta_path=feats,
-                alignment_dir=alignment_dir,
-            )
-
-        feats = self.feature_pipeline.process_features(
-            data, self.mode, batch_mode_str
-        )
-
-        return feats
-
-    def __len__(self):
-        return len(self.mapping.keys())
-class OpenFoldBatchSampler(torch.utils.data.BatchSampler):
-    """
-        A shameful hack.
-
-        In AlphaFold, certain batches are designated for loss clamping. The
-        exact method by which residue cropping within that batch is performed
-        depends on that designation.
-
-        In idiomatic PyTorch, such "batch-wide" properties generally do not
-        exist; samples are supposed to be generated independently and only
-        later batched. This class and OpenFoldDataset get around this design
-        limitation by encoding batch properties in the indices sent to the
-        Dataset.
-
-        While this works (and efficiently), it precludes the future use of an
-        IterableDataset (such as WebDataset), which doesn't use indices. In
-        that case, the same can be accomplished by delaying the feature
-        processing step to the collate_fn, an argument of the DataLoader. That
-        solution is avoided here because it requires loading an entire batch's
-        worth of uncropped features into memory at a time.
-
-        A third option would be to generate two separate Dataset objects, one
-        that generates "clamped" batches and another for "unclamped" ones.
-        However, this would require parsing the precomputed caches of most
-        proteins twice, once for each loader. Given how lopsided the chances of
-        drawing a "clamped" batch are, care would also have to be taken not
-        to allocate too many resources to the less used DataLoader.
-    """
-    def __init__(self, config, **kwargs):
-        super(OpenFoldBatchSampler, self).__init__(**kwargs)
-        self.config = config
-        self.no_batch_modes = len(self.config.common.batch_modes)
-
-    def __iter__(self):
-        it = super().__iter__()
-        distr = torch.distributions.categorical.Categorical(
-            torch.tensor(
-                [prob for name, prob in self.config.common.batch_modes]
-            )
-        )
-        for sample in it:
-            mode_idx = distr.sample().item()
-            sample = [s * self.no_batch_modes + mode_idx for s in sample]
-            yield sample
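A worked example of the index encoding the removed OpenFoldBatchSampler's docstring describes, with two batch modes: the sampler multiplies each raw index by no_batch_modes and adds the drawn mode, and the removed OpenFoldDataset.__getitem__ above inverts this with a modulo and an integer division (the values below are illustrative):

    no_batch_modes = 2        # e.g. ("clamped", "unclamped")
    raw_sample = [5, 9, 12]   # indices drawn by the underlying RandomSampler
    mode_idx = 1              # drawn once for the whole batch

    encoded = [s * no_batch_modes + mode_idx for s in raw_sample]  # [11, 19, 25]
    # Inversion inside __getitem__:
    decoded = [(i % no_batch_modes, i // no_batch_modes) for i in encoded]
    # [(1, 5), (1, 9), (1, 12)] -> the same mode for every element of the batch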
-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(
-        self,
-        config: mlc.ConfigDict,
-        template_mmcif_dir: str,
-        max_template_date: str,
-        train_data_dir: Optional[str] = None,
-        train_alignment_dir: Optional[str] = None,
-        val_data_dir: Optional[str] = None,
-        val_alignment_dir: Optional[str] = None,
-        predict_data_dir: Optional[str] = None,
-        predict_alignment_dir: Optional[str] = None,
-        kalign_binary_path: str = '/usr/bin/kalign',
-        train_mapping_path: Optional[str] = None,
-        mmcif_cache_dir: str = 'tmp/',
-        **kwargs
-    ):
-        super(OpenFoldDataModule, self).__init__()
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.mmcif_cache_dir = mmcif_cache_dir
-
-        if self.train_data_dir is None and self.predict_data_dir is None:
-            raise ValueError(
-                'At least one of train_data_dir or predict_data_dir must be '
-                'specified'
-            )
-
-        self.training_mode = self.train_data_dir is not None
-        if self.training_mode and self.train_alignment_dir is None:
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif not self.training_mode and self.predict_alingment_dir is None:
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
-            )
-        elif val_data_dir is not None and val_alignment_dir is None:
-            raise ValueError(
-                'If val_data_dir is specified, val_alignment_dir must '
-                'be specified as well'
-            )
-
-    def setup(self, stage):
-        # Most of the arguments are the same for the three datasets
-        dataset_gen = partial(
-            OpenFoldDataset,
-            template_mmcif_dir=self.template_mmcif_dir,
-            max_template_date=self.max_template_date,
-            config=self.config,
-            kalign_binary_path=self.kalign_binary_path,
-            mmcif_cache_dir=self.mmcif_cache_dir,
-            use_small_bfd=self.config.data_module.use_small_bfd,
-        )
-
-        if self.training_mode:
-            self.train_dataset = dataset_gen(
-                data_dir=self.train_data_dir,
-                alignment_dir=self.train_alignment_dir,
-                mapping_path=self.train_mapping_path,
-                mode='train',
-            )
-
-            if self.val_data_dir is not None:
-                self.val_dataset = dataset_gen(
-                    data_dir=self.val_data_dir,
-                    alignment_dir=self.val_alignment_dir,
-                    mapping_path=None,
-                    mode='val',
-                )
-        else:
-            self.predict_dataset = dataset_gen(
-                data_dir=self.predict_data_dir,
-                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
-                mode='predict',
-            )
-
-    def train_dataloader(self):
-        stack_fn = partial(torch.stack, dim=0)
-        stack = lambda l: dict_multimap(stack_fn, l)
-        return torch.utils.data.DataLoader(
-            self.train_dataset,
-            batch_sampler=OpenFoldBatchSampler(
-                config=self.config,
-                sampler=RandomSampler(self.train_dataset),
-                batch_size=self.config.data_module.data_loaders.batch_size,
-                drop_last=False,
-            ),
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=stack,
-        )
-
-    def val_dataloader(self):
-        stack_fn = partial(torch.stack, dim=0)
-        stack = lambda l: dict_multimap(stack_fn, l)
-        return torch.utils.data.DataLoader(
-            self.val_dataset,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=stack
-        )
-
-    def predict_dataloader(self):
-        stack_fn = partial(torch.stack, dim=0)
-        stack = lambda l: dict_multimap(stack_fn, l)
-        return torch.utils.data.DataLoader(
-            self.predict_dataset,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=stack
-        )
 class OpenFoldWrapper(pl.LightningModule):
...
@@ -380,6 +61,8 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=eps
         )

+    def on_before_zero_grad(self, *args, **kwargs):
+        self.ema.update(self.model)
+

 def main(args):
     config = model_config(
...
@@ -421,6 +104,14 @@ if __name__ == "__main__":
         help="""Cutoff for all templates. In training mode, templates are also
                 filtered by the release date of the target"""
     )
+    parser.add_argument(
+        "--distillation_data_dir", type=str, default=None,
+        help="Directory containing training PDB files"
+    )
+    parser.add_argument(
+        "--distillation_alignment_dir", type=str, default=None,
+        help="Directory containing precomputed distillation alignments"
+    )
     parser.add_argument(
         "--val_data_dir", type=str, default=None,
         help="Directory containing validation mmCIF files"
...
@@ -440,16 +131,20 @@ if __name__ == "__main__":
                 the training set"""
     )
     parser.add_argument(
-        "--mmcif_cache_dir", type=str, default="tmp/",
-        help="Directory containing precomputed mmCIF metadata"
+        "--distillation_mapping_path", type=str, default=None,
+        help="""See --train_mapping_path"""
+    )
+    parser.add_argument(
+        "--template_release_dates_cache_path", type=str, default=None,
+        help="Output of templates.generate_mmcif_cache"
     )
     parser.add_argument(
         "--use_small_bfd", type=bool, default=False,
         help="Whether to use a reduced version of the BFD database"
     )
     parser.add_argument(
-        "--seed", type=int, default=42,
-        help="Random seed for the DataModule"
+        "--seed", type=int, default=None,
+        help="Random seed"
     )
     parser = pl.Trainer.add_argparse_args(parser)
...
@@ -459,7 +154,9 @@ if __name__ == "__main__":
     args = parser.parse_args()

+    # Seed torch
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
+        random.seed(args.seed + 1)
+        np.random.seed(args.seed + 2)
+
     main(args)