Commit 3a4ed753 authored by Jennifer
Browse files

Removes the OpenFold (OF) copy of zero_to_fp32.py in favor of the deepspeed.utils version

parent e4f9af23
......@@ -22,7 +22,9 @@ import shutil
import torch
from openfold.utils.import_weights import convert_deprecated_v1_keys
from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file
from deepspeed.utils.zero_to_fp32 import (
get_optim_files, parse_optim_states, get_model_state_file
)
def convert_v1_to_v2_weights(args):
......
This diff is collapsed.
......@@ -14,6 +14,7 @@ from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch
import wandb
from deepspeed.utils import zero_to_fp32
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
......@@ -39,11 +40,6 @@ from openfold.utils.import_weights import (
import_jax_weights_,
import_openfold_weights_
)
from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint,
get_global_step_from_zero_checkpoint
)
from openfold.utils.logger import PerformanceLoggingCallback
......@@ -276,6 +272,18 @@ class OpenFoldWrapper(pl.LightningModule):
self.model, jax_path, version=model_version
)
def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
    """Load the model-states file of the most recent DeepSpeed checkpoint.

    The checkpoint tag is read from the ``latest`` file inside
    ``checkpoint_dir``; the model-states file is then located in the
    ``checkpoint_dir/<tag>`` sub-directory and loaded with ``torch.load``.

    Args:
        checkpoint_dir: Directory that contains the ``latest`` file and the
            per-tag checkpoint sub-directories.

    Returns:
        Whatever ``torch.load`` returns for the model-states file. The caller
        in ``main`` reads the ``'global_step'`` entry from it.

    Raises:
        ValueError: If the ``latest`` file is missing, or if it contains no
            tag (empty or whitespace only).
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    with open(latest_path, 'r') as fd:
        tag = fd.read().strip()

    # An empty tag would make the join below resolve to checkpoint_dir itself
    # and the lookup would silently target the wrong directory.
    if not tag:
        raise ValueError(f"'latest' file at {latest_path} does not name a checkpoint tag")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    # NOTE(review): deepspeed's helper appears to interpret its second
    # argument as the ZeRO stage rather than a checkpoint version -- confirm
    # against the installed deepspeed release.
    _DS_CHECKPOINT_VERSION = 2  # based on manual parsing of checkpoint files
    state_file = zero_to_fp32.get_model_state_file(
        ds_checkpoint_dir, _DS_CHECKPOINT_VERSION
    )
    return torch.load(state_file)
def main(args):
if(args.seed is not None):
......@@ -297,7 +305,7 @@ def main(args):
if args.resume_model_weights_only:
# Load the checkpoint
if os.path.isdir(args.resume_from_ckpt):
sd = get_fp32_state_dict_from_zero_checkpoint(
sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
......@@ -316,11 +324,10 @@ def main(args):
else: # Loads a checkpoint to start from a specific time step
if os.path.isdir(args.resume_from_ckpt):
last_global_step = get_global_step_from_zero_checkpoint(
args.resume_from_ckpt)
sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment