Commit a7274ef0 authored by Zhang690683220's avatar Zhang690683220
Browse files

Fix incorrect learning-rate warm-up after restarting from a checkpoint

parent ce2e1f29
......@@ -13,6 +13,7 @@ import glob
import math
import os
from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
......@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """Read the global training step recorded in a DeepSpeed ZeRO checkpoint.

    DeepSpeed writes a ``latest`` file into the checkpoint directory whose
    contents name the most recent checkpoint tag, e.g. ``global_step1500``.
    This parses the step count out of that tag so training can resume the
    learning-rate schedule from the correct step.

    Args:
        checkpoint_dir: Path to the DeepSpeed checkpoint directory.

    Returns:
        The global step as an ``int``.

    Raises:
        ValueError: If the ``latest`` file is missing, or if its tag does
            not match the expected ``global_step<N>`` format.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    # Guard clause: fail fast with a clear message when the marker is absent.
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    match = re.match(r"global_step([0-9]+)", tag)
    # The original code called match.group(1) unconditionally, which raised
    # an opaque AttributeError on an unexpected tag; raise ValueError instead.
    if match is None:
        raise ValueError(
            f"Unrecognized checkpoint tag {tag!r} in {latest_path}; "
            f"expected 'global_step<N>'"
        )
    return int(match.group(1))
if __name__ == "__main__":
......
......@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
gdt_ha,
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback
......@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
self.cached_weights = None
self.last_lr_step = 0
self.last_lr_step = -1
def forward(self, batch):
    """Delegate the forward pass to the wrapped AlphaFold model.

    Args:
        batch: The input feature batch expected by the underlying model.

    Returns:
        Whatever the wrapped model produces for this batch.
    """
    model_output = self.model(batch)
    return model_output
......@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
lr=learning_rate,
eps=eps
)
if self.last_lr_step != -1:
for group in optimizer.param_groups:
if 'initial_lr' not in group:
group['initial_lr'] = learning_rate
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
......@@ -249,6 +256,10 @@ def main(args):
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment