Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
a7274ef0
Commit
a7274ef0
authored
Jul 29, 2022
by
Zhang690683220
Browse files
fix incorrect learning rate warm-up after restarting from ckpt
parent
ce2e1f29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
2 deletions
+25
-2
scripts/zero_to_fp32.py
scripts/zero_to_fp32.py
+12
-0
train_openfold.py
train_openfold.py
+13
-2
No files found.
scripts/zero_to_fp32.py
View file @
a7274ef0
...
@@ -13,6 +13,7 @@ import glob
...
@@ -13,6 +13,7 @@ import glob
import
math
import
math
import
os
import
os
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
# DeepSpeed data structures it has to be available in the current python environment.
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
...
@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return
model
return
model
def
get_global_step_from_zero_checkpoint
(
checkpoint_dir
):
global_step
=
-
1
latest_path
=
os
.
path
.
join
(
checkpoint_dir
,
'latest'
)
if
os
.
path
.
isfile
(
latest_path
):
with
open
(
latest_path
,
'r'
)
as
fd
:
tag
=
fd
.
read
().
strip
()
match
=
re
.
match
(
r
"global_step([0-9]+)"
,
tag
)
global_step
=
int
(
match
.
group
(
1
))
else
:
raise
ValueError
(
f
"Unable to find 'latest' file at
{
latest_path
}
"
)
return
global_step
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
train_openfold.py
View file @
a7274ef0
...
@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
...
@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
gdt_ha
,
gdt_ha
,
)
)
from
scripts.zero_to_fp32
import
(
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
)
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
)
self
.
cached_weights
=
None
self
.
cached_weights
=
None
self
.
last_lr_step
=
0
self
.
last_lr_step
=
-
1
def
forward
(
self
,
batch
):
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
...
@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
lr
=
learning_rate
,
lr
=
learning_rate
,
eps
=
eps
eps
=
eps
)
)
if
self
.
last_lr_step
!=
-
1
:
for
group
in
optimizer
.
param_groups
:
if
'initial_lr'
not
in
group
:
group
[
'initial_lr'
]
=
learning_rate
lr_scheduler
=
AlphaFoldLRScheduler
(
lr_scheduler
=
AlphaFoldLRScheduler
(
optimizer
,
optimizer
,
)
)
...
@@ -249,6 +256,10 @@ def main(args):
...
@@ -249,6 +256,10 @@ def main(args):
)
)
model_module
=
OpenFoldWrapper
(
config
)
model_module
=
OpenFoldWrapper
(
config
)
if
(
args
.
resume_from_ckpt
):
last_global_step
=
get_global_step_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
model_module
.
resume_last_lr_step
(
last_global_step
)
logging
.
info
(
"Successfully loaded last lr step..."
)
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment