Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
0cf1541c
Commit
0cf1541c
authored
Oct 16, 2023
by
Christina Floristean
Browse files
Refactoring multimer data pipeline and permutation alignment.
parent
377f854c
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
660 additions
and
789 deletions
+660
-789
environment.yml
environment.yml
+2
-0
openfold/config.py
openfold/config.py
+5
-1
openfold/data/data_modules.py
openfold/data/data_modules.py
+371
-438
openfold/data/feature_pipeline.py
openfold/data/feature_pipeline.py
+5
-18
openfold/data/input_pipeline_multimer.py
openfold/data/input_pipeline_multimer.py
+25
-20
openfold/utils/loss.py
openfold/utils/loss.py
+207
-230
scripts/generate_mmcif_cache.py
scripts/generate_mmcif_cache.py
+36
-2
train_openfold.py
train_openfold.py
+9
-80
No files found.
environment.yml
View file @
0cf1541c
...
...
@@ -19,6 +19,8 @@ dependencies:
-
deepspeed==0.5.10
-
dm-tree==0.1.6
-
ml-collections==0.1.0
-
jax==0.3.25
-
pandas==2.0.2
-
numpy==1.21.2
-
PyYAML==5.4.1
-
requests==2.26.0
...
...
openfold/config.py
View file @
0cf1541c
...
...
@@ -749,6 +749,9 @@ multimer_config_update = mlc.ConfigDict({
"sym_id"
,
]
},
"supervised"
:
{
"clamp_prob"
:
1.
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
...
...
@@ -765,7 +768,8 @@ multimer_config_update = mlc.ConfigDict({
"max_extra_msa"
:
2048
,
"crop_size"
:
640
,
"spatial_crop_prob"
:
0.5
,
"interface_threshold"
:
10.
"interface_threshold"
:
10.
,
"clamp_prob"
:
1.
,
},
},
"model"
:
{
...
...
openfold/data/data_modules.py
View file @
0cf1541c
This diff is collapsed.
Click to expand it.
openfold/data/feature_pipeline.py
View file @
0cf1541c
...
...
@@ -93,24 +93,11 @@ def np_example_to_features(
with
torch
.
no_grad
():
if
is_multimer
:
if
mode
==
'train'
:
features
,
gt_features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
is_training
=
True
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()},
gt_features
else
:
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
is_training
=
False
)
return
{
k
:
v
for
k
,
v
in
features
.
items
()}
features
=
input_pipeline_multimer
.
process_tensors_from_config
(
tensor_dict
,
cfg
.
common
,
cfg
[
mode
],
)
else
:
features
=
input_pipeline
.
process_tensors_from_config
(
tensor_dict
,
...
...
openfold/data/input_pipeline_multimer.py
View file @
0cf1541c
...
...
@@ -21,16 +21,17 @@ from openfold.data import (
data_transforms_multimer
,
)
def
grountruth_transforms_fns
():
transforms
=
[
data_transforms
.
make_atom14_masks
,
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
""
),
data_transforms
.
make_pseudo_beta
(
""
),
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
,
]
return
transforms
def
groundtruth_transforms_fns
():
transforms
=
[
data_transforms
.
make_atom14_masks
,
data_transforms
.
make_atom14_positions
,
data_transforms
.
atom37_to_frames
,
data_transforms
.
atom37_to_torsion_angles
(
""
),
data_transforms
.
make_pseudo_beta
(
""
),
data_transforms
.
get_backbone_frames
,
data_transforms
.
get_chi_angles
]
return
transforms
def
nonensembled_transform_fns
():
"""Input pipeline data transformers that are not ensembled."""
...
...
@@ -112,20 +113,24 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return
transforms
def
prepare_ground_truth_features
(
tensors
):
"""Prepare ground truth features that are only needed for loss calculation during training"""
GROUNDTRUTH_FEATURES
=
[
'all_atom_mask'
,
'all_atom_positions'
,
'asym_id'
,
'sym_id'
,
'entity_id'
]
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
GROUNDTRUTH_FEATURES
}
gt_features
=
[
'all_atom_mask'
,
'all_atom_positions'
,
'asym_id'
,
'sym_id'
,
'entity_id'
]
gt_tensors
=
{
k
:
v
for
k
,
v
in
tensors
.
items
()
if
k
in
gt_features
}
gt_tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
gt_tensors
=
compose
(
grountruth_transforms_fns
())(
gt_tensors
)
gt_tensors
=
compose
(
groun
d
truth_transforms_fns
())(
gt_tensors
)
return
gt_tensors
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
,
is_training
=
False
):
def
process_tensors_from_config
(
tensors
,
common_cfg
,
mode_cfg
):
"""Based on the config, apply filters and transformations to the data."""
if
is_training
:
gt_tensors
=
prepare_ground_truth_features
(
tensors
)
process_gt_feats
=
mode_cfg
.
supervised
gt_tensors
=
{}
if
process_gt_feats
:
gt_tensors
=
prepare_ground_truth_features
(
tensors
)
ensemble_seed
=
random
.
randint
(
0
,
torch
.
iinfo
(
torch
.
int32
).
max
)
tensors
[
'aatype'
]
=
tensors
[
'aatype'
].
to
(
torch
.
long
)
...
...
@@ -152,10 +157,10 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False)
lambda
x
:
wrap_ensemble_fn
(
tensors
,
x
),
torch
.
arange
(
num_recycling
+
1
)
)
if
is_training
:
return
tensors
,
gt_tensors
else
:
return
tensors
if
process_gt_feats
:
tensors
[
'gt_features'
]
=
gt_tensors
return
tensors
@
data_transforms
.
curry1
def
compose
(
x
,
fs
):
...
...
openfold/utils/loss.py
View file @
0cf1541c
This diff is collapsed.
Click to expand it.
scripts/generate_mmcif_cache.py
View file @
0cf1541c
...
...
@@ -13,7 +13,7 @@ from tqdm import tqdm
from
openfold.data.mmcif_parsing
import
parse
def
parse_file
(
f
,
args
):
def
parse_file
(
f
,
args
,
chain_cluster_size_dict
=
None
):
with
open
(
os
.
path
.
join
(
args
.
mmcif_dir
,
f
),
"r"
)
as
fp
:
mmcif_string
=
fp
.
read
()
file_id
=
os
.
path
.
splitext
(
f
)[
0
]
...
...
@@ -28,6 +28,18 @@ def parse_file(f, args):
local_data
[
"release_date"
]
=
mmcif
.
header
[
"release_date"
]
chain_ids
,
seqs
=
list
(
zip
(
*
mmcif
.
chain_to_seqres
.
items
()))
if
chain_cluster_size_dict
is
not
None
:
cluster_sizes
=
[]
for
chain_id
in
chain_ids
:
full_name
=
"_"
.
join
([
file_id
,
chain_id
])
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
cluster_sizes
.
append
(
cluster_size
)
local_data
[
"cluster_sizes"
]
=
cluster_sizes
local_data
[
"chain_ids"
]
=
chain_ids
local_data
[
"seqs"
]
=
seqs
local_data
[
"no_chains"
]
=
len
(
chain_ids
)
...
...
@@ -38,8 +50,21 @@ def parse_file(f, args):
def
main
(
args
):
chain_cluster_size_dict
=
None
if
args
.
cluster_file
is
not
None
:
chain_cluster_size_dict
=
{}
with
open
(
args
.
cluster_file
,
"r"
)
as
fp
:
clusters
=
[
l
.
strip
()
for
l
in
fp
.
readlines
()]
for
cluster
in
clusters
:
chain_ids
=
cluster
.
split
()
cluster_len
=
len
(
chain_ids
)
for
chain_id
in
chain_ids
:
chain_id
=
chain_id
.
upper
()
chain_cluster_size_dict
[
chain_id
]
=
cluster_len
files
=
[
f
for
f
in
os
.
listdir
(
args
.
mmcif_dir
)
if
".cif"
in
f
]
fn
=
partial
(
parse_file
,
args
=
args
)
fn
=
partial
(
parse_file
,
args
=
args
,
chain_cluster_size_dict
=
chain_cluster_size_dict
)
data
=
{}
with
Pool
(
processes
=
args
.
no_workers
)
as
p
:
with
tqdm
(
total
=
len
(
files
))
as
pbar
:
...
...
@@ -63,6 +88,15 @@ if __name__ == "__main__":
"--no_workers"
,
type
=
int
,
default
=
4
,
help
=
"Number of workers to use for parsing"
)
parser
.
add_argument
(
"--cluster_file"
,
type
=
str
,
default
=
None
,
help
=
(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser
.
add_argument
(
"--chunksize"
,
type
=
int
,
default
=
10
,
help
=
"How many files should be distributed to each worker at a time"
...
...
train_openfold.py
View file @
0cf1541c
import
argparse
import
logging
import
os
import
random
import
sys
import
time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
import
torch
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
(
OpenFoldDataModule
,
OpenFoldMultimerDataModule
,
DummyDataLoader
,
)
from
openfold.data.data_modules
import
OpenFoldDataModule
,
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.np
import
residue_constants
...
...
@@ -53,7 +46,12 @@ class OpenFoldWrapper(pl.LightningModule):
super
(
OpenFoldWrapper
,
self
).
__init__
()
self
.
config
=
config
self
.
model
=
AlphaFold
(
config
)
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
if
self
.
config
.
globals
.
is_multimer
:
self
.
loss
=
AlphaFoldMultimerLoss
(
config
.
loss
)
else
:
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
model
=
self
.
model
,
decay
=
config
.
ema
.
decay
)
...
...
@@ -256,72 +254,6 @@ class OpenFoldWrapper(pl.LightningModule):
)
class
OpenFoldMultimerWrapper
(
OpenFoldWrapper
):
def
__init__
(
self
,
config
):
super
(
OpenFoldMultimerWrapper
,
self
).
__init__
(
config
)
self
.
config
=
config
self
.
model
=
AlphaFold
(
config
)
self
.
loss
=
AlphaFoldMultimerLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
model
=
self
.
model
,
decay
=
config
.
ema
.
decay
)
self
.
cached_weights
=
None
self
.
last_lr_step
=
-
1
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
def
training_step
(
self
,
batch
,
batch_idx
):
features
,
gt_features
=
batch
# Log it
if
(
self
.
ema
.
device
!=
features
[
"aatype"
].
device
):
self
.
ema
.
to
(
features
[
"aatype"
].
device
)
# Run the model
outputs
=
self
(
features
)
# Remove the recycling dimension
features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
features
)
# Compute loss
loss
,
loss_breakdown
=
self
.
loss
(
outputs
,
(
features
,
gt_features
),
_return_breakdown
=
True
)
# Log it
self
.
_log
(
loss_breakdown
,
features
,
outputs
)
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
features
,
gt_features
=
batch
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param
=
lambda
t
:
t
.
detach
().
clone
()
self
.
cached_weights
=
tensor_tree_map
(
clone_param
,
self
.
model
.
state_dict
())
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Run the model
outputs
=
self
(
features
)
# Compute loss and other metrics
features
[
"use_clamped_fape"
]
=
0.
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
(
features
,
gt_features
),
_return_breakdown
=
True
)
self
.
_log
(
loss_breakdown
,
features
,
outputs
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
self
.
model
.
load_state_dict
(
self
.
cached_weights
)
self
.
cached_weights
=
None
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
seed_everything
(
args
.
seed
)
...
...
@@ -331,10 +263,8 @@ def main(args):
train
=
True
,
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
)
if
"multimer"
in
args
.
config_preset
:
model_module
=
OpenFoldMultimerWrapper
(
config
)
else
:
model_module
=
OpenFoldWrapper
(
config
)
model_module
=
OpenFoldWrapper
(
config
)
if
(
args
.
resume_from_ckpt
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
last_global_step
=
get_global_step_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
...
...
@@ -359,7 +289,6 @@ def main(args):
if
(
args
.
script_modules
):
script_preset_
(
model_module
)
#data_module = DummyDataLoader("new_batch.pickle")
if
"multimer"
in
args
.
config_preset
:
data_module
=
OpenFoldMultimerDataModule
(
config
=
config
.
data
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment