Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
87f3cd45
Unverified
Commit
87f3cd45
authored
Aug 02, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Aug 02, 2022
Browse files
Merge pull request #182 from Zhang690683220/fix
fix incorrect learning rate warm-up after restarting from ckpt
parents
2648f26a
a2e7dabb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
2 deletions
+32
-2
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+12
-0
train_openfold.py
train_openfold.py
+20
-2
No files found.
scripts/zero_to_fp32.py
View file @
87f3cd45
...
@@ -13,6 +13,7 @@ import glob
...
@@ -13,6 +13,7 @@ import glob
import
math
import
math
import
os
import
os
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
# DeepSpeed data structures it has to be available in the current python environment.
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return
model
return
model
def
get_global_step_from_zero_checkpoint
(
checkpoint_dir
):
global_step
=
-
1
latest_path
=
os
.
path
.
join
(
checkpoint_dir
,
'latest'
)
if
os
.
path
.
isfile
(
latest_path
):
with
open
(
latest_path
,
'r'
)
as
fd
:
tag
=
fd
.
read
().
strip
()
match
=
re
.
match
(
r
"global_step([0-9]+)"
,
tag
)
global_step
=
int
(
match
.
group
(
1
))
else
:
raise
ValueError
(
f
"Unable to find 'latest' file at
{
latest_path
}
"
)
return
global_step
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
train_openfold.py
View file @
87f3cd45
...
@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
...
@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
gdt_ha
,
gdt_ha
,
)
)
from
scripts.zero_to_fp32
import
(
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
)
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
)
self
.
cached_weights
=
None
self
.
cached_weights
=
None
self
.
last_lr_step
=
0
self
.
last_lr_step
=
-
1
def
forward
(
self
,
batch
):
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
...
@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
lr
=
learning_rate
,
lr
=
learning_rate
,
eps
=
eps
eps
=
eps
)
)
if
self
.
last_lr_step
!=
-
1
:
for
group
in
optimizer
.
param_groups
:
if
'initial_lr'
not
in
group
:
group
[
'initial_lr'
]
=
learning_rate
lr_scheduler
=
AlphaFoldLRScheduler
(
lr_scheduler
=
AlphaFoldLRScheduler
(
optimizer
,
optimizer
,
)
)
...
@@ -237,6 +244,9 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -237,6 +244,9 @@ class OpenFoldWrapper(pl.LightningModule):
def
on_save_checkpoint
(
self
,
checkpoint
):
def
on_save_checkpoint
(
self
,
checkpoint
):
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
def
resume_last_lr_step
(
self
,
lr_step
):
self
.
last_lr_step
=
lr_step
def
main
(
args
):
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
if
(
args
.
seed
is
not
None
):
...
@@ -249,6 +259,14 @@ def main(args):
...
@@ -249,6 +259,14 @@ def main(args):
)
)
model_module
=
OpenFoldWrapper
(
config
)
model_module
=
OpenFoldWrapper
(
config
)
if
(
args
.
resume_from_ckpt
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
last_global_step
=
get_global_step_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
last_global_step
=
int
(
sd
[
'global_step'
])
model_module
.
resume_last_lr_step
(
last_global_step
)
logging
.
info
(
"Successfully loaded last lr step..."
)
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment