OpenDAS / OpenFold

Commit 3dcc01a7
Authored Mar 20, 2022 by Gustaf Ahdritz

Tweak training script. Install new LR scheduler

Parent: 72a971b0
Showing 2 changed files with 34 additions and 33 deletions

  deepspeed_config.json   +2  -9
  train_openfold.py       +32 -24
deepspeed_config.json
 {
-    "optimizer": {
-        "type": "Adam",
-        "params": {
-            "lr": 0.001,
-            "eps": 1e-05
-        }
-    },
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "min_loss_scale": 1
     },
     "amp": {
@@ -15,7 +8,7 @@
         "opt_level": "O2"
     },
     "bfloat16": {
-        "enabled": false
+        "enabled": true
     },
     "zero_optimization": {
         "stage": 2,
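Removing the "optimizer" block hands optimizer construction back to the training script: with no optimizer in the DeepSpeed config, DeepSpeed wraps the "client" optimizer returned by configure_optimizers() in train_openfold.py, which is what lets the new AlphaFoldLRScheduler take effect. Swapping "fp16" off for "bfloat16" trades dynamic loss scaling for bf16's wider exponent range. Below is a minimal sketch of how such a config could be handed to a Lightning trainer, assuming the script routes it through Lightning's DeepSpeed plugin; the actual wiring is not part of this diff and the file path is an assumption.

    # Hypothetical wiring, not taken from train_openfold.py: pass the JSON to
    # Lightning's DeepSpeed plugin so DeepSpeed defers to the client optimizer
    # and LR scheduler returned by configure_optimizers().
    import pytorch_lightning as pl
    from pytorch_lightning.plugins import DeepSpeedPlugin

    plugin = DeepSpeedPlugin(config="deepspeed_config.json")  # path is an assumption
    trainer = pl.Trainer(gpus=1, plugins=[plugin])
    # trainer.fit(model_module, data_module)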
train_openfold.py
@@ -27,12 +27,13 @@ from openfold.data.data_modules import (
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
+from openfold.utils.argparse import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.argparse import remove_arguments
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
+from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
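The only functional addition here is the AlphaFoldLRScheduler import; the remove_arguments import is just moved to its alphabetical position. The scheduler's implementation is not part of this diff. AlphaFold's published training recipe warms the learning rate up linearly, holds it at the base value, and decays it later in training; purely as an illustration of that shape, here is a hypothetical warm-up/hold/decay schedule built on torch's LambdaLR. The real class in openfold/utils/lr_schedulers.py may differ in shape and defaults, and every numeric value below is an assumption.

    # Illustrative only: a warm-up/hold/decay schedule in the spirit of the
    # AlphaFold recipe, not the actual AlphaFoldLRScheduler.
    import torch

    def make_warmup_hold_decay(
        optimizer,
        warmup_steps=1000,       # assumed value
        start_decay_step=50000,  # assumed value
        decay_factor=0.95,       # assumed value
    ):
        def lr_lambda(step):
            if step < warmup_steps:
                return (step + 1) / warmup_steps  # linear warm-up
            if step < start_decay_step:
                return 1.0                        # hold at the base LR
            return decay_factor                   # decayed thereafter
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    # sched = make_warmup_hold_decay(opt)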
@@ -58,7 +59,6 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
-        self.last_lr_step = 0

     def forward(self, batch):
         return self.model(batch)
@@ -72,12 +72,12 @@ class OpenFoldWrapper(pl.LightningModule):
                 on_step=train, on_epoch=(not train), logger=True,
             )

             if(train):
                 self.log(
-                    f"train/loss_epoch",
-                    loss_breakdown["loss"],
+                    f"{phase}/{loss_name}_epoch",
+                    indiv_loss,
                     on_step=False, on_epoch=True, logger=True,
                 )

         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
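The epoch-level log keys change from a single hardcoded "train/loss_epoch" entry (the total loss) to "{phase}/{loss_name}_epoch" with the individual loss value, which suggests, though the surrounding lines are outside this hunk, that each component of the loss breakdown now gets its own epoch aggregate. A toy sketch of the naming this produces, with made-up loss names and values:

    # Made-up loss names and values, only to show the resulting log keys.
    loss_breakdown = {"loss": 12.3, "fape": 4.5, "plddt_loss": 0.8}
    train = True
    phase = "train" if train else "val"
    for loss_name, indiv_loss in loss_breakdown.items():
        print(f"{phase}/{loss_name}", f"{phase}/{loss_name}_epoch", indiv_loss)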
@@ -116,19 +116,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

-    # def training_step_end(self, outputs):
-    #     # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
-    #     if(self.trainer.global_step != self.last_lr_step):
-    #         self.lr_schedulers().step()
-    #         self.last_lr_step = self.trainer.global_step
-
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            # load_state_dict() is an in-place operation
-            # it will change the content in any reference of model.state_dict()
-            # therefore we need to explicitly clone the parameters
-            clone_param = lambda t: t.clone().detach()
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())

         self.model.load_state_dict(self.ema.state_dict()["params"])
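The reworded comment spells out the pitfall: state_dict() returns references to the live parameters, so caching it without cloning would be silently overwritten the moment the EMA weights are loaded. A small, self-contained demonstration (standalone PyTorch, not OpenFold code):

    # Why the weights must be cloned before load_state_dict(): state_dict()
    # hands back tensors that share storage with the live parameters.
    import torch

    model = torch.nn.Linear(2, 2)
    sd_ref = model.state_dict()                                  # references into the model
    cached = {k: v.detach().clone() for k, v in sd_ref.items()}  # true copies

    with torch.no_grad():
        model.weight.zero_()                                     # stand-in for loading EMA weights

    print(sd_ref["weight"].abs().sum().item())  # 0.0: the cached reference changed too
    print(cached["weight"].abs().sum().item())  # > 0: the clone kept the old values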
@@ -175,15 +169,15 @@ class OpenFoldWrapper(pl.LightningModule):
            eps=self.config.globals.eps,
            per_residue=False,
        )
        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = compute_drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
-           mask=all_atom_mask_ca,
+           mask=all_atom_mask_ca, # still required here to compute n
        )
        metrics["drmsd_ca"] = drmsd_ca_score

        if(superimposition_metrics):
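The new trailing comment records why the mask is still passed even though the Cα coordinates are already masked: the denominator of the dRMSD, the count n of valid positions/pairs, has to come from the mask rather than from zero-padded coordinates. A rough, hypothetical sketch of that idea; compute_drmsd in openfold.utils.loss may differ in its exact formula and broadcasting.

    # Sketch only: masked dRMSD where "n" is derived from the mask, not the
    # (already zero-padded) coordinates.
    import torch

    def drmsd_sketch(coords_a, coords_b, mask):
        d_a = torch.cdist(coords_a, coords_a)        # pairwise distances, structure A
        d_b = torch.cdist(coords_b, coords_b)        # pairwise distances, structure B
        pair_mask = mask[..., :, None] * mask[..., None, :]
        sq_diff = (d_a - d_b) ** 2 * pair_mask
        n = pair_mask.sum()                          # "n" comes from the mask
        return torch.sqrt(sq_diff.sum() / n.clamp(min=1))

    # a = torch.randn(10, 3); b = a + 0.1 * torch.randn(10, 3)
    # print(drmsd_sketch(a, b, torch.ones(10)))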
@@ -207,11 +201,23 @@ class OpenFoldWrapper(pl.LightningModule):
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
-        return torch.optim.Adam(
+        optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
+
+        lr_scheduler = AlphaFoldLRScheduler(
+            optimizer,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+                "name": "AlphaFoldLRScheduler",
+            }
+        }

     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
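configure_optimizers now returns Lightning's optimizer-plus-scheduler dictionary; with "interval": "step" the scheduler is advanced after every optimizer step rather than once per epoch, which appears to be why the commented-out training_step_end workaround above was deleted. The "name" key only matters if a LearningRateMonitor callback is attached, since it sets the key under which the learning rate is logged. A hedged sketch of that pairing; the Trainer arguments are assumptions, not values from this script.

    # Sketch: log the stepped learning rate under the scheduler's "name".
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import LearningRateMonitor

    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer = pl.Trainer(callbacks=[lr_monitor], max_steps=1000)  # args are assumptions
    # trainer.fit(model_module, data_module)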
@@ -236,7 +242,7 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)
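For context on the weight-loading lines shown above: checkpoints written through distributed wrappers (DDP or DeepSpeed) prefix every parameter key with "module.", which has to be stripped before calling load_state_dict() on the bare model. A toy illustration of that dictionary comprehension:

    # Toy example of stripping the distributed "module." prefix from checkpoint keys.
    sd = {"module.linear.weight": "w", "module.linear.bias": "b"}
    sd = {k[len("module."):]: v for k, v in sd.items()}
    print(sd)  # {'linear.weight': 'w', 'linear.bias': 'b'}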
@@ -255,6 +261,8 @@ def main(args):
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
+            auto_insert_metric_name=False,
+            save_top_k=-1,
         )
         callbacks.append(mc)
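save_top_k=-1 tells ModelCheckpoint to keep every checkpoint it writes instead of only the best k, and auto_insert_metric_name=False keeps monitored-metric names out of the checkpoint file names. A fuller, hypothetical configuration consistent with these flags; the dirpath and filename template are assumptions, not values from train_openfold.py.

    # Hypothetical ModelCheckpoint setup matching the flags added above.
    from pytorch_lightning.callbacks import ModelCheckpoint

    mc = ModelCheckpoint(
        dirpath="checkpoints/",              # assumed
        filename="epoch{epoch}-step{step}",  # assumed template
        every_n_epochs=1,
        auto_insert_metric_name=False,
        save_top_k=-1,                       # keep every checkpoint
    )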