Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
6298a3e6
Commit
6298a3e6
authored
Oct 23, 2021
by
Gustaf Ahdritz
Browse files
Fix PDB parsing, add distillation MSA cropping
parent
4d40ce80
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
21 deletions
+27
-21
openfold/config.py
openfold/config.py
+4
-0
openfold/data/data_modules.py
openfold/data/data_modules.py
+3
-2
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+4
-7
openfold/data/data_transforms.py
openfold/data/data_transforms.py
+7
-9
openfold/data/input_pipeline.py
openfold/data/input_pipeline.py
+7
-1
train_openfold.py
train_openfold.py
+2
-2
No files found.
openfold/config.py
View file @
6298a3e6
...
...
@@ -179,6 +179,7 @@ config = mlc.ConfigDict(
"all_atom_positions"
,
"resolution"
,
"use_clamped_fape"
,
"is_distillation"
,
],
},
"predict"
:
{
...
...
@@ -192,6 +193,7 @@ config = mlc.ConfigDict(
"crop"
:
False
,
"crop_size"
:
None
,
"supervised"
:
False
,
"subsample_recycling"
:
False
,
},
"eval"
:
{
"fixed_size"
:
True
,
...
...
@@ -204,6 +206,7 @@ config = mlc.ConfigDict(
"crop"
:
False
,
"crop_size"
:
None
,
"supervised"
:
True
,
"subsample_recycling"
:
False
,
},
"train"
:
{
"fixed_size"
:
True
,
...
...
@@ -218,6 +221,7 @@ config = mlc.ConfigDict(
"supervised"
:
True
,
"clamp_prob"
:
0.9
,
"subsample_recycling"
:
True
,
"max_distillation_msa_clusters"
:
1000
,
},
"data_module"
:
{
"use_small_bfd"
:
False
,
...
...
openfold/data/data_modules.py
View file @
6298a3e6
import
copy
from
functools
import
partial
import
json
import
logging
...
...
@@ -462,9 +463,9 @@ class DummyDataset(torch.utils.data.Dataset):
class
DummyDataLoader
(
pl
.
LightningDataModule
):
def
__init__
(
self
):
def
__init__
(
self
,
batch_path
):
super
().
__init__
()
self
.
dataset
=
Dataset
()
self
.
dataset
=
Dummy
Dataset
(
batch_path
)
def
train_dataloader
(
self
):
return
torch
.
utils
.
data
.
DataLoader
(
self
.
dataset
)
openfold/data/data_pipeline.py
View file @
6298a3e6
...
...
@@ -113,6 +113,7 @@ def make_pdb_features(
pdb_feats
[
"all_atom_positions"
]
=
all_atom_positions
pdb_feats
[
"all_atom_mask"
]
=
all_atom_mask
pdb_feats
[
"resolution"
]
=
np
.
array
([
0.
]).
astype
(
np
.
float32
)
pdb_feats
[
"is_distillation"
]
=
np
.
array
(
1.
).
astype
(
np
.
float32
)
return
pdb_feats
...
...
@@ -412,16 +413,12 @@ class DataPipeline:
pdb_feats
=
make_pdb_features
(
protein_object
)
mmcif_feats
=
make_mmcif_features
(
mmcif
,
chain_id
)
alignments
=
self
.
_parse_alignment_output
(
alignment_dir
)
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
templates_result
=
self
.
template_featurizer
.
get_templates
(
query_sequence
=
input_sequenc
e
,
query_sequence
=
protein_object
.
aatyp
e
,
query_pdb_code
=
None
,
query_release_date
=
to_date
(
mmcif
.
header
[
"release_date"
])
,
query_release_date
=
None
,
hits
=
alignments
[
"hhsearch_hits"
],
)
...
...
@@ -438,4 +435,4 @@ class DataPipeline:
),
)
return
{
**
mmcif
_feats
,
**
templates_result
.
features
,
**
msa_features
}
return
{
**
pdb
_feats
,
**
templates_result
.
features
,
**
msa_features
}
openfold/data/data_transforms.py
View file @
6298a3e6
...
...
@@ -77,14 +77,6 @@ def curry1(f):
return
fc
@
curry1
def
add_distillation_flag
(
protein
,
distillation
):
protein
[
"is_distillation"
]
=
torch
.
tensor
(
float
(
distillation
),
dtype
=
torch
.
float32
)
return
protein
def
make_all_atom_aatype
(
protein
):
protein
[
"all_atom_aatype"
]
=
protein
[
"aatype"
]
return
protein
...
...
@@ -176,7 +168,6 @@ def randomly_replace_msa_with_unknown(protein, replace_proportion):
)
return
protein
@
curry1
def
sample_msa
(
protein
,
max_seq
,
keep_extra
):
"""Sample MSA randomly; remaining sequences are stored as `extra_*`."""
...
...
@@ -198,6 +189,13 @@ def sample_msa(protein, max_seq, keep_extra):
return
protein
@
curry1
def
sample_msa_distillation
(
protein
,
max_seq
):
if
(
protein
[
"is_distillation"
]
==
1
):
protein
=
sample_msa
(
protein
,
max_seq
,
keep_extra
=
False
)
return
protein
@
curry1
def
crop_extra_msa
(
protein
,
max_extra_msa
):
num_seq
=
protein
[
"extra_msa"
].
shape
[
0
]
...
...
openfold/data/input_pipeline.py
View file @
6298a3e6
...
...
@@ -25,7 +25,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
transforms
=
[
data_transforms
.
cast_to_64bit_ints
,
data_transforms
.
correct_msa_restypes
,
data_transforms
.
add_distillation_flag
(
False
),
data_transforms
.
squeeze_features
,
data_transforms
.
randomly_replace_msa_with_unknown
(
0.0
),
data_transforms
.
make_seq_mask
,
...
...
@@ -72,6 +71,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms
=
[]
if
"max_distillation_msa_clusters"
in
mode_cfg
:
transforms
.
append
(
data_transforms
.
sample_msa_distillation
(
mode_cfg
.
max_distillation_msa_clusters
)
)
if
common_cfg
.
reduce_msa_clusters_by_max_templates
:
pad_msa_clusters
=
mode_cfg
.
max_msa_clusters
-
mode_cfg
.
max_templates
else
:
...
...
train_openfold.py
View file @
6298a3e6
...
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
os
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"6"
#
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
...
...
@@ -223,7 +223,7 @@ if __name__ == "__main__":
help
=
"Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser
.
add_argument
(
"--checkpoint_best_val"
,
type
=
int
,
default
=
True
,
"--checkpoint_best_val"
,
type
=
bool
,
default
=
True
,
help
=
"""Whether to save the model parameters that perform best during
validation"""
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment