OpenDAS / OpenFold · Commit 4b354151
Authored Jul 20, 2023 by Geoffrey Yu · Parent: d886a7be

    update train_openfold.py to accommodate multimer training
Showing 1 changed file with 84 additions and 6 deletions.

train_openfold.py (+84, -6)
```diff
@@ -16,18 +16,18 @@ import torch
 from openfold.config import model_config
 from openfold.data.data_modules import (
-    OpenFoldDataModule,
+    OpenFoldDataModule,
+    OpenFoldMultimerDataModule,
     DummyDataLoader,
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
+from openfold.utils.argparse_utils import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.loss import AlphaFoldLoss, AlphaFoldMultimerLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
```
```diff
@@ -257,6 +257,69 @@ class OpenFoldWrapper(pl.LightningModule):
         )


+class OpenFoldMultimerWrapper(OpenFoldWrapper):
+    def __init__(self, config):
+        super(OpenFoldMultimerWrapper, self).__init__(config)
+        self.config = config
+        # Need to overwrite this in the multimer loss config.
+        self.config.loss.masked_msa.num_classes = 22
+        # No need to go overboard here.
+        self.config.model.evoformer_stack.no_blocks = 4
+        # Don't want to set up activation checkpointing.
+        self.config.model.evoformer_stack.blocks_per_ckpt = None
+        self.model = AlphaFold(config)
+        self.loss = AlphaFoldMultimerLoss(config.loss)
+        self.ema = ExponentialMovingAverage(
+            model=self.model, decay=config.ema.decay
+        )
+        self.cached_weights = None
+        self.last_lr_step = -1
+
+    def forward(self, batch):
+        return self.model(batch)
+
+    def training_step(self, batch, batch_idx):
+        all_chain_features, ground_truth = batch
+        if(self.ema.device != all_chain_features["aatype"].device):
+            self.ema.to(all_chain_features["aatype"].device)
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss
+        loss = self.loss(
+            outputs, (all_chain_features, ground_truth), _return_breakdown=False
+        )
+
+        # Log it
+        self._log(loss, all_chain_features, outputs)
+
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        all_chain_features, ground_truth = batch
+        # At the start of validation, load the EMA weights
+        if(self.cached_weights is None):
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            self.model.load_state_dict(self.ema.state_dict()["params"])
+
+        # Run the model
+        outputs = self(all_chain_features)
+
+        # Compute loss and other metrics
+        all_chain_features["use_clamped_fape"] = 0.
+        # Pass the same (features, ground_truth) pair that training_step uses.
+        _, loss_breakdown = self.loss(
+            outputs, (all_chain_features, ground_truth), _return_breakdown=True
+        )
+
+        self._log(loss_breakdown, all_chain_features, outputs, train=False)
+
+    def validation_epoch_end(self, _):
+        # Restore the model weights to normal
+        self.model.load_state_dict(self.cached_weights)
+        self.cached_weights = None
+
+
 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)
```
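The new wrapper reuses `OpenFoldWrapper`'s EMA swap pattern: at the start of validation the live weights are cloned into `cached_weights` and replaced by the EMA shadow weights, then restored in `validation_epoch_end`. A minimal standalone sketch of that pattern, using a toy module and plain dict comprehensions in place of `AlphaFold` and OpenFold's `tensor_tree_map`:

```python
import torch

# Toy stand-ins for AlphaFold and ExponentialMovingAverage.
model = torch.nn.Linear(4, 4)
ema_params = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Start of validation: state_dict() holds references to the live weights,
# not copies, so clone them before swapping in the EMA weights.
cached_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}
model.load_state_dict(ema_params)

# ... run validation with the EMA weights ...

# End of validation: restore the live training weights.
model.load_state_dict(cached_weights)
cached_weights = None
```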
```diff
@@ -266,8 +329,11 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
-    model_module = OpenFoldWrapper(config)
+    if "multimer" in args.config_preset:
+        print("training multimer models now")
+        model_module = OpenFoldMultimerWrapper(config)
+    else:
+        model_module = OpenFoldWrapper(config)
     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
```
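The `"multimer"` substring check on `args.config_preset` appears both here and in the data-module selection below, and the two branches must stay in sync. A hypothetical helper (not part of this commit) that resolves both classes in one place, using the wrapper and data-module names from this file:

```python
# Hypothetical refactor: pick both classes from the preset string in one
# place so the two "multimer" checks cannot drift apart.
def select_classes(config_preset):
    if "multimer" in config_preset:
        return OpenFoldMultimerWrapper, OpenFoldMultimerDataModule
    return OpenFoldWrapper, OpenFoldDataModule

wrapper_cls, data_module_cls = select_classes(args.config_preset)
model_module = wrapper_cls(config)
```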
```diff
@@ -293,11 +359,19 @@ def main(args):
     script_preset_(model_module)

     #data_module = DummyDataLoader("new_batch.pickle")
-    data_module = OpenFoldDataModule(
-        config=config.data,
-        batch_seed=args.seed,
-        **vars(args)
-    )
+    if "multimer" in args.config_preset:
+        print("use multimer datamodule now")
+        data_module = OpenFoldMultimerDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )
+    else:
+        data_module = OpenFoldDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )

     data_module.prepare_data()
     data_module.setup()
```
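Unlike `OpenFoldDataModule`, whose batches are a single feature dict, the multimer path yields `(all_chain_features, ground_truth)` pairs, which is what `OpenFoldMultimerWrapper.training_step` unpacks above. A dummy batch in that shape, where the keys and shapes are illustrative assumptions rather than the data module's actual schema:

```python
import torch

n_res = 120  # total residues across two chains (illustrative)
all_chain_features = {
    "aatype": torch.zeros(n_res, dtype=torch.long),    # residue type indices
    "asym_id": torch.arange(2).repeat_interleave(60),  # chain membership ids
}
ground_truth = {
    "all_atom_positions": torch.zeros(n_res, 37, 3),   # per-atom coordinates
}
batch = (all_chain_features, ground_truth)

# What training_step does with the batch:
features, gt = batch
print(features["aatype"].shape, gt["all_atom_positions"].shape)
```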
```diff
@@ -417,6 +491,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--train_mmcif_data_cache_path", type=str, default=None,
+        help="Path to a JSON file recording metadata for the mmCIF structures used in training"
+    )
     parser.add_argument(
         "--distillation_data_dir", type=str, default=None,
         help="Directory containing training PDB files"
```
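The diff does not show the schema of the file behind `--train_mmcif_data_cache_path`. A hypothetical example of writing and reading such a cache; the field names here are guesses for illustration, not OpenFold's actual format:

```python
import json

# Hypothetical cache keyed by PDB id; fields are illustrative guesses.
cache = {
    "1abc": {
        "release_date": "2001-04-25",
        "chain_ids": ["A", "B"],
        "resolution": 2.1,
    },
}
with open("train_mmcif_data_cache.json", "w") as f:
    json.dump(cache, f, indent=2)

with open("train_mmcif_data_cache.json") as f:
    mmcif_data_cache = json.load(f)
print(mmcif_data_cache["1abc"]["release_date"])
```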