Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4c9d372d
"examples/rust/vscode:/vscode.git/clone" did not exist on "b760c5694df723227dc016e7f35c0fb66955e0d3"
Commit
4c9d372d
authored
Oct 23, 2021
by
Gustaf Ahdritz
Browse files
Add training callbacks
parent
727e68c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
78 additions
and
6 deletions
+78
-6
openfold/utils/callbacks.py
openfold/utils/callbacks.py
+14
-0
train_openfold.py
train_openfold.py
+64
-6
No files found.
openfold/utils/callbacks.py
0 → 100644
View file @
4c9d372d
from
pytorch_lightning.utilities
import
rank_zero_info
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
class
EarlyStoppingVerbose
(
EarlyStopping
):
"""
The default EarlyStopping callback's verbose mode is too verbose.
This class outputs a message only when it's getting ready to stop.
"""
def
_evalute_stopping_criteria
(
self
,
*
args
):
should_stop
,
reason
=
super
().
_evalute_stopping_criteria
(
*
args
)
if
(
should_stop
):
rank_zero_info
(
f
"
{
reason
}
\n
"
)
return
should_stop
,
reason
train_openfold.py
View file @
4c9d372d
...
...
@@ -2,7 +2,7 @@ import argparse
import
logging
import
os
#
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
"6"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
...
...
@@ -13,6 +13,7 @@ import time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.plugins
import
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
import
torch
...
...
@@ -23,6 +24,9 @@ from openfold.data.data_modules import (
DummyDataLoader
,
)
from
openfold.model.model
import
AlphaFold
from
openfold.utils.callbacks
import
(
EarlyStoppingVerbose
,
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.seed
import
seed_everything
...
...
@@ -88,6 +92,9 @@ class OpenFoldWrapper(pl.LightningModule):
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
self
.
ema
.
update
(
self
.
model
)
def
on_save_checkpoint
(
self
,
checkpoint
):
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
...
...
@@ -108,7 +115,29 @@ def main(args):
)
data_module
.
prepare_data
()
data_module
.
setup
()
callbacks
=
[]
if
(
args
.
checkpoint_best_val
):
checkpoint_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoints"
)
mc
=
ModelCheckpoint
(
dirpath
=
checkpoint_dir
,
filename
=
"openfold_{epoch}_{step}_{val_loss:.2f}"
,
monitor
=
"val_loss"
,
)
callbacks
.
append
(
mc
)
if
(
args
.
early_stopping
):
es
=
EarlyStoppingVerbose
(
monitor
=
"val_loss"
,
min_delta
=
args
.
min_delta
,
patience
=
args
.
patience
,
verbose
=
False
,
mode
=
"min"
,
check_finite
=
True
,
strict
=
True
,
)
callbacks
.
append
(
es
)
plugins
=
[]
if
(
args
.
deepspeed_config_path
is
not
None
):
plugins
.
append
(
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
))
...
...
@@ -119,6 +148,7 @@ def main(args):
)
trainer
.
fit
(
model_module
,
datamodule
=
data_module
)
trainer
.
save_checkpoint
(
"final.ckpt"
)
if
__name__
==
"__main__"
:
...
...
@@ -135,10 +165,15 @@ if __name__ == "__main__":
"template_mmcif_dir"
,
type
=
str
,
help
=
"Directory containing mmCIF files to search for templates"
)
parser
.
add_argument
(
"output_dir"
,
type
=
str
,
help
=
'''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
)
parser
.
add_argument
(
"max_template_date"
,
type
=
str
,
help
=
"""
Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target
"""
help
=
'''
Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target
'''
)
parser
.
add_argument
(
"--distillation_data_dir"
,
type
=
str
,
default
=
None
,
...
...
@@ -162,9 +197,9 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--train_mapping_path"
,
type
=
str
,
default
=
None
,
help
=
"""
Optional path to a .json file containing a mapping from
help
=
'''
Optional path to a .json file containing a mapping from
consecutive numerical indices to sample names. Used to filter
the training set
"""
the training set
'''
)
parser
.
add_argument
(
"--distillation_mapping_path"
,
type
=
str
,
default
=
None
,
...
...
@@ -187,6 +222,24 @@ if __name__ == "__main__":
"--deepspeed_config_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser
.
add_argument
(
"--checkpoint_best_val"
,
type
=
int
,
default
=
True
,
help
=
"""Whether to save the model parameters that perform best during
validation"""
)
parser
.
add_argument
(
"--early_stopping"
,
type
=
bool
,
default
=
False
,
help
=
"Whether to stop training when validation loss fails to decrease"
)
parser
.
add_argument
(
"--min_delta"
,
type
=
float
,
default
=
0
,
help
=
"""The smallest decrease in validation loss that counts as an
improvement for the purposes of early stopping"""
)
parser
.
add_argument
(
"--patience"
,
type
=
int
,
default
=
3
,
help
=
"Early stopping patience"
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
.
set_defaults
(
...
...
@@ -195,4 +248,9 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
if
(
args
.
seed
is
None
and
((
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
raise
ValueError
(
"For distributed training, --seed must be specified"
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment