Unverified Commit 87f3cd45 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #182 from Zhang690683220/fix

fix incorrect learning rate warm-up after restarting from ckpt
parents 2648f26a a2e7dabb
......@@ -13,6 +13,7 @@ import glob
import math
import os
from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
......@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """Return the global training step recorded in a DeepSpeed ZeRO checkpoint.

    DeepSpeed writes a ``latest`` file in the checkpoint directory whose
    contents are a tag of the form ``global_step<N>``. This reads that tag
    and returns ``N`` as an int.

    Args:
        checkpoint_dir: Path to the DeepSpeed checkpoint directory.

    Returns:
        The global step parsed from the checkpoint tag.

    Raises:
        ValueError: If the ``latest`` file is missing, or its tag does not
            match the expected ``global_step<N>`` format (previously this
            second case crashed with an opaque ``AttributeError`` because
            ``re.match`` returned ``None``).
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    # Guard clause: fail loudly if the checkpoint has no 'latest' marker.
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    match = re.match(r"global_step([0-9]+)", tag)
    if match is None:
        raise ValueError(
            f"Unexpected checkpoint tag {tag!r} in {latest_path}; "
            f"expected 'global_step<N>'"
        )
    return int(match.group(1))
if __name__ == "__main__":
......
......@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
gdt_ha,
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback
......@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
self.cached_weights = None
self.last_lr_step = 0
self.last_lr_step = -1
def forward(self, batch):
    """Run the wrapped OpenFold model on a single input batch.

    Args:
        batch: The model input batch (passed through unchanged).

    Returns:
        Whatever ``self.model`` produces for this batch.
    """
    outputs = self.model(batch)
    return outputs
......@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
lr=learning_rate,
eps=eps
)
if self.last_lr_step != -1:
for group in optimizer.param_groups:
if 'initial_lr' not in group:
group['initial_lr'] = learning_rate
lr_scheduler = AlphaFoldLRScheduler(
optimizer,
)
......@@ -237,6 +244,9 @@ class OpenFoldWrapper(pl.LightningModule):
def on_save_checkpoint(self, checkpoint):
    """Lightning hook: stash the EMA weights into the checkpoint dict."""
    ema_state = self.ema.state_dict()
    checkpoint["ema"] = ema_state
def resume_last_lr_step(self, lr_step):
    """Record the LR-scheduler step recovered from a checkpoint.

    Stored on the wrapper so that scheduler construction can resume the
    warm-up schedule from this step instead of restarting at zero.

    Args:
        lr_step: Global step read from the resumed checkpoint.
    """
    self.last_lr_step = lr_step
def main(args):
if(args.seed is not None):
......@@ -249,6 +259,14 @@ def main(args):
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment