Unverified Commit 87f3cd45 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #182 from Zhang690683220/fix

fix incorrect learning rate warm-up after restarting from ckpt
parents 2648f26a a2e7dabb
...@@ -13,6 +13,7 @@ import glob ...@@ -13,6 +13,7 @@ import glob
import math import math
import os import os
from collections import OrderedDict from collections import OrderedDict
import re
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment. # DeepSpeed data structures it has to be available in the current python environment.
...@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): ...@@ -431,6 +432,17 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
return model return model
def get_global_step_from_zero_checkpoint(checkpoint_dir):
    """Return the global training step recorded in a DeepSpeed ZeRO checkpoint.

    DeepSpeed writes a ``latest`` file in the checkpoint directory whose
    contents are the checkpoint tag (e.g. ``global_step1000``); the numeric
    suffix is the global step.

    Args:
        checkpoint_dir: Path to the checkpoint directory containing ``latest``.

    Returns:
        int: The parsed global step.

    Raises:
        ValueError: If the ``latest`` file is missing, or its contents do not
            look like ``global_step<N>``.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()
    match = re.match(r"global_step([0-9]+)", tag)
    if match is None:
        # Previously this crashed with AttributeError (None.group) on an
        # unexpected tag; raise a descriptive error instead.
        raise ValueError(
            f"Unable to parse global step from checkpoint tag '{tag}' "
            f"in {latest_path}"
        )
    return int(match.group(1))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import ( ...@@ -44,7 +44,8 @@ from openfold.utils.validation_metrics import (
gdt_ha, gdt_ha,
) )
from scripts.zero_to_fp32 import ( from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
) )
from openfold.utils.logger import PerformanceLoggingCallback from openfold.utils.logger import PerformanceLoggingCallback
...@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -61,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
) )
self.cached_weights = None self.cached_weights = None
self.last_lr_step = 0 self.last_lr_step = -1
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
...@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -215,6 +216,12 @@ class OpenFoldWrapper(pl.LightningModule):
lr=learning_rate, lr=learning_rate,
eps=eps eps=eps
) )
if self.last_lr_step != -1:
for group in optimizer.param_groups:
if 'initial_lr' not in group:
group['initial_lr'] = learning_rate
lr_scheduler = AlphaFoldLRScheduler( lr_scheduler = AlphaFoldLRScheduler(
optimizer, optimizer,
) )
...@@ -237,6 +244,9 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -237,6 +244,9 @@ class OpenFoldWrapper(pl.LightningModule):
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
checkpoint["ema"] = self.ema.state_dict() checkpoint["ema"] = self.ema.state_dict()
def resume_last_lr_step(self, lr_step):
    """Record the LR-scheduler step to resume from after loading a checkpoint.

    Args:
        lr_step: Global step read from the checkpoint; used so the learning
            rate warm-up continues from where training stopped rather than
            restarting from step 0.
    """
    setattr(self, "last_lr_step", lr_step)
def main(args): def main(args):
if(args.seed is not None): if(args.seed is not None):
...@@ -249,6 +259,14 @@ def main(args): ...@@ -249,6 +259,14 @@ def main(args):
) )
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only): if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)): if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment