OpenDAS / OpenFold · Commits

Commit fbec92cb
Update training script
Authored Feb 06, 2022 by Gustaf Ahdritz
Parent: 70362e4b

Showing 1 changed file with 17 additions and 21 deletions.

train_openfold.py  +17 −21
@@ -12,6 +12,7 @@ import time
 import numpy as np
 import pytorch_lightning as pl
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
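Note: the only change in this hunk is the new LearningRateMonitor import. In PyTorch Lightning it is attached to the Trainer as a callback so that scheduler behavior (e.g. warmup) shows up in the experiment logger. A minimal sketch; the Trainer arguments are illustrative, not the ones this script uses:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor

    # Log the current learning rate of every optimizer at each step
    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = pl.Trainer(callbacks=[lr_monitor])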
@@ -68,10 +69,14 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)
 
+        # Log it
         self.log("train/loss", loss, on_step=True, logger=True)
 
         return loss
 
+    def on_before_zero_grad(self, *args, **kwargs):
+        self.ema.update(self.model)
+
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
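Note: `on_before_zero_grad` is a standard LightningModule hook that fires after `optimizer.step()` and before `optimizer.zero_grad()`, so the EMA here tracks the freshly updated weights (the hook is moved up from later in the file; see the removal hunk below). OpenFold's own `ExponentialMovingAverage` class is not shown in this diff; a minimal sketch of the idea, with an assumed decay constant:

    import torch

    class SimpleEMA:
        """Tracks an exponential moving average of a model's parameters."""
        def __init__(self, model: torch.nn.Module, decay: float = 0.999):
            self.decay = decay
            # Detached copies, so EMA updates never touch the autograd graph
            self.params = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }

        def update(self, model: torch.nn.Module):
            with torch.no_grad():
                for k, v in model.state_dict().items():
                    if v.dtype.is_floating_point:
                        # ema = decay * ema + (1 - decay) * current
                        self.params[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
                    else:
                        # Non-float buffers (e.g. counters) are copied as-is
                        self.params[k].copy_(v)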
@@ -81,13 +86,17 @@ class OpenFoldWrapper(pl.LightningModule):
         # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
-        loss = lddt_ca(
+        lddt_ca_score = lddt_ca(
             outputs["final_atom_positions"],
             batch["all_atom_positions"],
             batch["all_atom_mask"],
             eps=self.config.globals.eps,
             per_residue=False,
         )
+        self.log("val/lddt_ca", lddt_ca_score, logger=True)
+
+        batch["use_clamped_fape"] = 0.
+        loss = self.loss(outputs, batch)
         self.log("val/loss", loss, logger=True)
 
     def validation_epoch_end(self, _):
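Note: two behavioral changes here. The CA-lDDT score now gets its own metric key instead of being logged as "val/loss", and the actual validation loss is recomputed from the full loss function with `use_clamped_fape` forced to 0 (unclamped FAPE). `tensor_tree_map` applies a function to every tensor in a nested feature dict; a plausible sketch of such a helper (the real one lives in OpenFold's tensor utilities and may differ in detail):

    import torch

    def tensor_tree_map(fn, tree):
        """Recursively apply fn to every torch.Tensor in a nested structure."""
        if isinstance(tree, dict):
            return {k: tensor_tree_map(fn, v) for k, v in tree.items()}
        elif isinstance(tree, (list, tuple)):
            return type(tree)(tensor_tree_map(fn, v) for v in tree)
        elif isinstance(tree, torch.Tensor):
            return fn(tree)
        return tree

    # As used above: keep only the final recycling iteration of each feature
    # batch = tensor_tree_map(lambda t: t[..., -1], batch)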
@@ -106,9 +115,6 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=eps
         )
 
-    def on_before_zero_grad(self, *args, **kwargs):
-        self.ema.update(self.model)
-
     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
 
@@ -137,7 +143,7 @@ def main(args):
     if(args.script_modules):
         script_preset_(model_module)
 
-    #data_module = DummyDataLoader("batch.pickle")
+    #data_module = DummyDataLoader("new_batch.pickle")
     data_module = OpenFoldDataModule(
         config=config.data,
         batch_seed=args.seed,
@@ -148,22 +154,19 @@ def main(args):
     data_module.setup()
 
     callbacks = []
-    if(args.checkpoint_best_val):
-        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
+    if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
-            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
-            monitor="val/loss",
-            mode="max",
+            every_n_epochs=1,
         )
         callbacks.append(mc)
 
     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val/loss",
+            monitor="val/lddt_ca",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
-            mode="min",
+            mode="max",
             check_finite=True,
             strict=True,
         )
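Note: the old checkpoint callback monitored "val/loss" with mode="max", which would have kept the *worst* checkpoint; the new version sidesteps that by simply checkpointing every epoch. Early stopping now watches the newly logged "val/lddt_ca", which increases as predictions improve, hence mode="max". With stock Lightning callbacks the equivalent setup would look roughly like this (`EarlyStoppingVerbose` is OpenFold's own subclass, assumed here to accept the same arguments as `EarlyStopping`; the numeric values are illustrative):

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    callbacks = [
        # Save a checkpoint unconditionally at the end of every epoch
        ModelCheckpoint(every_n_epochs=1),
        # Stop when CA-lDDT plateaus; higher is better, so mode="max"
        EarlyStopping(
            monitor="val/lddt_ca",
            min_delta=0.001,   # illustrative value
            patience=10,       # illustrative value
            mode="max",
            check_finite=True,
            strict=True,
        ),
    ]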
@@ -189,14 +192,8 @@ def main(args):
         loggers.append(wdb_logger)
 
     if(args.deepspeed_config_path is not None):
-        #if "SLURM_JOB_ID" in os.environ:
-        #    cluster_environment = SLURMEnvironment()
-        #else:
-        #    cluster_environment = None
-
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
-            # cluster_environment=cluster_environment,
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
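Note: this hunk only deletes dead, commented-out SLURM cluster-environment scaffolding. The configured plugin is then handed to the Trainer; in Lightning of this era a DeepSpeed run would be wired up roughly as follows (the surrounding Trainer arguments are illustrative, not taken from this script):

    trainer = pl.Trainer(
        gpus=1,              # illustrative
        precision=16,        # DeepSpeed is typically run in mixed precision
        strategy=strategy,   # the DeepSpeedPlugin built above
    )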
@@ -313,9 +310,8 @@ if __name__ == "__main__":
         help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
     )
     parser.add_argument(
-        "--checkpoint_best_val", type=bool_type, default=True,
-        help="""Whether to save the model parameters that perform best during
-             validation"""
+        "--checkpoint_every_epoch", action="store_true", default=False,
+        help="""Whether to checkpoint at the end of every training epoch"""
     )
     parser.add_argument(
         "--early_stopping", type=bool_type, default=False,
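Note: the argument change also alters the CLI contract. With `type=bool_type` the option required an explicit value (e.g. `--checkpoint_best_val True`), whereas `action="store_true"` makes it a plain switch. A minimal standalone illustration:

    import argparse

    parser = argparse.ArgumentParser()
    # A bare switch: absent -> False, present -> True, no value accepted
    parser.add_argument("--checkpoint_every_epoch", action="store_true", default=False)

    args = parser.parse_args(["--checkpoint_every_epoch"])
    assert args.checkpoint_every_epoch is True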