OpenDAS / OpenFold · Commits · 3dcc01a7

Commit 3dcc01a7, authored Mar 20, 2022 by Gustaf Ahdritz
Parent: 72a971b0

    Tweak training script. Install new LR scheduler

Showing 2 changed files with 34 additions and 33 deletions:

    deepspeed_config.json    +2   -9
    train_openfold.py        +32  -24
deepspeed_config.json (view file @ 3dcc01a7)

 {
-    "optimizer": {
-        "type": "Adam",
-        "params": {
-            "lr": 0.001,
-            "eps": 1e-05
-        }
-    },
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "min_loss_scale": 1
     },
     "amp": {
...
@@ -15,7 +8,7 @@
         "opt_level": "O2"
     },
     "bfloat16": {
-        "enabled": false
+        "enabled": true
     },
     "zero_optimization": {
         "stage": 2,
...
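The net effect of the hunks above is a precision switch: the standalone optimizer block is dropped from the DeepSpeed config and training moves from fp16 to bfloat16 mixed precision. As an illustrative sanity check (not part of the commit; the file path and variable names here are mine), the two flags should never be enabled at the same time:

    # Illustrative only: verify the fp16/bfloat16 flags are mutually exclusive.
    import json

    with open("deepspeed_config.json") as f:
        ds_config = json.load(f)

    fp16_enabled = ds_config.get("fp16", {}).get("enabled", False)
    bf16_enabled = ds_config.get("bfloat16", {}).get("enabled", False)
    assert not (fp16_enabled and bf16_enabled), "enable at most one of fp16/bfloat16"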
train_openfold.py (view file @ 3dcc01a7)

...
@@ -27,12 +27,13 @@ from openfold.data.data_modules import (
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
+from openfold.utils.argparse import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.argparse import remove_arguments
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
+from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
...
@@ -58,7 +59,6 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
-        self.last_lr_step = 0
 
     def forward(self, batch):
         return self.model(batch)
...
@@ -72,12 +72,12 @@ class OpenFoldWrapper(pl.LightningModule):
                 on_step=train, on_epoch=(not train), logger=True,
             )
 
-            if(train):
-                self.log(
-                    f"train/loss_epoch",
-                    loss_breakdown["loss"],
-                    on_step=False, on_epoch=True, logger=True,
-                )
+            if(train):
+                self.log(
+                    f"{phase}/{loss_name}_epoch",
+                    indiv_loss,
+                    on_step=False, on_epoch=True, logger=True,
+                )
 
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
...
@@ -116,19 +116,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)
 
-    # def training_step_end(self, outputs):
-    #     # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
-    #     if(self.trainer.global_step != self.last_lr_step):
-    #         self.lr_schedulers().step()
-    #         self.last_lr_step = self.trainer.global_step
-
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            # load_state_dict() is an in-place operation
-            # it will change the content in any reference of model.state_dict()
-            # therefore we need to explicitly clone the parameters
-            clone_param = lambda t: t.clone().detach()
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])
...
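The rewritten comment above is the whole point of the clone_param lambda: model.state_dict() hands back references to the live parameter tensors, so caching them without a copy would let the subsequent load of EMA weights silently overwrite the cache. A minimal standalone demonstration (not from the repository; the toy Linear layer is only for illustration):

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)

    # Reference: shares storage with the live parameter.
    cached_ref = model.state_dict()["weight"]
    # Copy, as in clone_param above: detached, with its own storage.
    cached_copy = model.state_dict()["weight"].detach().clone()

    with torch.no_grad():
        model.weight.add_(1.0)  # stand-in for loading the EMA parameters

    print(torch.equal(cached_ref, model.weight))   # True  -- the "cache" changed too
    print(torch.equal(cached_copy, model.weight))  # False -- the clone kept the old values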
@@ -175,15 +169,15 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
         metrics["lddt_ca"] = lddt_ca_score
 
         drmsd_ca_score = compute_drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
-            mask=all_atom_mask_ca,
+            mask=all_atom_mask_ca, # still required here to compute n
         )
         metrics["drmsd_ca"] = drmsd_ca_score
 
         if(superimposition_metrics):
...
@@ -207,11 +201,23 @@ class OpenFoldWrapper(pl.LightningModule):
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
-        return torch.optim.Adam(
+        optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
+
+        lr_scheduler = AlphaFoldLRScheduler(
+            optimizer,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+                "name": "AlphaFoldLRScheduler",
+            }
+        }
 
     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
...
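Because of the "interval": "step" entry, Lightning will call step() on the returned scheduler once per optimizer step rather than once per epoch. The AlphaFoldLRScheduler implementation itself is not part of this diff; as a rough sketch of the kind of schedule it presumably implements (linear warm-up to the base learning rate, a long flat phase, then a late decay, in the spirit of the AlphaFold training recipe), an equivalent could be built from LambdaLR. All step counts and the decay factor below are illustrative, not OpenFold's actual values:

    import torch

    def alphafold_like_lambda(step, warmup_steps=1000, start_decay_after=50000, decay_factor=0.95):
        # Linear warm-up, then hold at the base LR, then drop by a constant factor.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        if step > start_decay_after:
            return decay_factor
        return 1.0

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=alphafold_like_lambda)

    for step in range(5):
        optimizer.step()
        lr_scheduler.step()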
@@ -236,7 +242,7 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")
 
     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)
...
@@ -255,6 +261,8 @@ def main(args):
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
             save_top_k=-1,
         )
         callbacks.append(mc)
...
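For context on the checkpoint settings above: with every_n_epochs=1 and save_top_k=-1, Lightning's ModelCheckpoint keeps a checkpoint for every epoch rather than only the best-scoring ones, and auto_insert_metric_name=False stops Lightning from injecting metric names into the checkpoint filename. A hedged usage sketch (the Trainer wiring here is illustrative and not taken from this file):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    callbacks = []
    mc = ModelCheckpoint(
        every_n_epochs=1,              # checkpoint at the end of each epoch
        auto_insert_metric_name=False,
        save_top_k=-1,                 # keep every checkpoint instead of the best k
    )
    callbacks.append(mc)

    trainer = pl.Trainer(callbacks=callbacks)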