Commit 577219c1 authored by Jennifer

Removes OF copy of zero_to_fp32.py in favor of the deepspeed.utils version

parent 86263583
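For context: the deleted scripts/zero_to_fp32.py was a local copy of the conversion helper that DeepSpeed already ships as deepspeed.utils.zero_to_fp32. For orientation, a minimal sketch of that module's main public entry point (the checkpoint path below is a placeholder, not a path from this repository):

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Consolidate a sharded ZeRO checkpoint directory into a single fp32 state dict.
# 'checkpoints/last.ckpt' is a hypothetical DeepSpeed/Lightning checkpoint directory.
state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoints/last.ckpt")
torch.save(state_dict, "model_fp32.pt")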
@@ -22,7 +22,9 @@ import shutil
 import torch
 from openfold.utils.import_weights import convert_deprecated_v1_keys
-from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
+from deepspeed.utils.zero_to_fp32 import (
+    get_optim_files, parse_optim_states, get_model_state_file
+)
 
 def convert_v1_to_v2_weights(args):
...
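The three helpers imported above are module-level functions inside DeepSpeed's own zero_to_fp32.py rather than a stable public API, so their exact signatures are an assumption here; the sketch below shows roughly how they fit together in the DeepSpeed releases this change targets (paths are placeholders):

from deepspeed.utils.zero_to_fp32 import (
    get_optim_files, parse_optim_states, get_model_state_file
)

ds_checkpoint_dir = "checkpoints/global_step10000"    # placeholder: one tagged checkpoint subdirectory
optim_files = get_optim_files(ds_checkpoint_dir)      # per-rank *_optim_states.pt files
# parse_optim_states is assumed to return the ZeRO stage, world size, and flattened fp32 groups
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
model_state_file = get_model_state_file(ds_checkpoint_dir, zero_stage)  # rank-0 *_model_states.pt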
scripts/zero_to_fp32.py deleted (diff collapsed)
@@ -11,6 +11,7 @@ from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
 from pytorch_lightning import seed_everything
 import torch
+from deepspeed.utils import zero_to_fp32
 
 from openfold.config import model_config
 from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
@@ -35,11 +36,6 @@ from openfold.utils.import_weights import (
     import_jax_weights_,
     import_openfold_weights_
 )
-from scripts.zero_to_fp32 import (
-    get_fp32_state_dict_from_zero_checkpoint,
-    get_global_step_from_zero_checkpoint
-)
 
 from openfold.utils.logger import PerformanceLoggingCallback
@@ -274,6 +270,18 @@ class OpenFoldWrapper(pl.LightningModule):
             self.model, jax_path, version=model_version
         )
 
+
+def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
+    latest_path = os.path.join(checkpoint_dir, 'latest')
+    if os.path.isfile(latest_path):
+        with open(latest_path, 'r') as fd:
+            tag = fd.read().strip()
+    else:
+        raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+    _DS_CHECKPOINT_VERSION = 2  # based on manual parsing of checkpoint files
+    state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
+    return torch.load(state_file)
+
 def main(args):
     if (args.seed is not None):
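The helper added in this hunk resolves the 'latest' tag the same way DeepSpeed's conversion script does, but then loads only the rank-0 model-states file instead of reconstructing fp32 optimizer partitions, which is enough to recover bookkeeping values such as the global step. A hypothetical call, mirroring how the resume path below uses it (the path is a placeholder):

# Hypothetical usage; 'outputs/checkpoints/last.ckpt' stands in for a DeepSpeed
# checkpoint directory containing a 'latest' file plus a tagged subdirectory.
sd = get_model_state_dict_from_ds_checkpoint("outputs/checkpoints/last.ckpt")
last_global_step = int(sd['global_step'])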
@@ -314,11 +322,10 @@ def main(args):
         else:  # Loads a checkpoint to start from a specific time step
             if os.path.isdir(args.resume_from_ckpt):
-                last_global_step = get_global_step_from_zero_checkpoint(
-                    args.resume_from_ckpt)
+                sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
             else:
                 sd = torch.load(args.resume_from_ckpt)
-                last_global_step = int(sd['global_step'])
+            last_global_step = int(sd['global_step'])
             model_module.resume_last_lr_step(last_global_step)
             logging.info("Successfully loaded last lr step...")
...