Commit 6dc34d71 authored by Jennifer

First pass of changes to run with PyTorch Lightning (pl) 2.1
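A rough sketch of the PL 2.x surface this first pass moves to (illustrative only; ToyModule and ToyData are hypothetical stand-ins, not OpenFold code):

# Illustrative-only sketch of the PyTorch Lightning 2.x API changes applied here.
import torch
import pytorch_lightning as pl
from pytorch_lightning import seed_everything            # was pytorch_lightning.utilities.seed
from pytorch_lightning.strategies import DDPStrategy     # was pytorch_lightning.plugins.training_type.DDPPlugin


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).pow(2).mean()

    def on_validation_epoch_end(self):                    # was validation_epoch_end(self, outputs)
        pass

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class ToyData(pl.LightningDataModule):
    def setup(self, stage=None):                          # setup() now receives a stage argument
        self.ds = torch.randn(64, 4)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.ds, batch_size=8)

    def val_dataloader(self):
        return []                                         # empty list when no eval set is configured


if __name__ == "__main__":
    seed_everything(42)
    # Trainer.from_argparse_args() is gone in 2.x; kwargs are passed explicitly,
    # and strategies replace the old training-type plugins.
    strategy = DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto"
    trainer = pl.Trainer(
        max_epochs=1,
        strategy=strategy,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(ToyModule(), datamodule=ToyData())

The diff below applies the same renames to the OpenFold data modules and training script.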

parent 5f5a79a7
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         with open(distillation_alignment_index_path, "r") as fp:
             self.distillation_alignment_index = json.load(fp)
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )
 
-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()
@@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        # Temp fix to pass the validation step
+        return []
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")
@@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
@@ -2,7 +2,7 @@ import os
 import logging
 import random
 
 import numpy as np
-from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning import seed_everything
 
 from openfold.utils.suppress_output import SuppressLogging
@@ -8,7 +8,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
 import torch
 
 from openfold.config import model_config
@@ -56,7 +56,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
 
         self.cached_weights = None
         self.last_lr_step = -1
         self.save_hyperparameters
@@ -68,12 +68,12 @@ class OpenFoldWrapper(pl.LightningModule):
         phase = "train" if train else "val"
 
         for loss_name, indiv_loss in loss_breakdown.items():
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
                 on_step=train, on_epoch=(not train), logger=True,
             )
 
-            if(train):
+            if (train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,
@@ -82,12 +82,12 @@ class OpenFoldWrapper(pl.LightningModule):
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
                 batch,
                 outputs,
                 superimposition_metrics=(not train)
             )
 
-        for k,v in other_metrics.items():
+        for k, v in other_metrics.items():
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
@@ -95,7 +95,7 @@ class OpenFoldWrapper(pl.LightningModule):
             )
 
     def training_step(self, batch, batch_idx):
-        if(self.ema.device != batch["aatype"].device):
+        if (self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
 
         ground_truth = batch.pop('gt_features', None)
@@ -126,12 +126,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
-        if(self.cached_weights is None):
+        if (self.cached_weights is None):
             # model.state_dict() contains references to model weights rather
             # than copies. Therefore, we need to clone them before calling
             # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
-            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            def clone_param(t): return t.detach().clone()
+            self.cached_weights = tensor_tree_map(
+                clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])
 
         ground_truth = batch.pop('gt_features', None)
@@ -153,23 +154,23 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self._log(loss_breakdown, batch, outputs, train=False)
 
-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
 
     def _compute_validation_metrics(self,
         batch,
         outputs,
         superimposition_metrics=False
     ):
         metrics = {}
 
         gt_coords = batch["all_atom_positions"]
         pred_coords = outputs["final_atom_positions"]
         all_atom_mask = batch["all_atom_mask"]
 
         # This is super janky for superimposition. Fix later
         gt_coords_masked = gt_coords * all_atom_mask[..., None]
         pred_coords_masked = pred_coords * all_atom_mask[..., None]
@@ -177,7 +178,7 @@ class OpenFoldWrapper(pl.LightningModule):
         gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
         pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
         all_atom_mask_ca = all_atom_mask[..., ca_pos]
 
         lddt_ca_score = lddt_ca(
             pred_coords,
             gt_coords,
@@ -185,18 +186,18 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
 
         metrics["lddt_ca"] = lddt_ca_score
 
         drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
             mask=all_atom_mask_ca,  # still required here to compute n
         )
 
         metrics["drmsd_ca"] = drmsd_ca_score
 
-        if(superimposition_metrics):
+        if (superimposition_metrics):
             superimposed_pred, alignment_rmsd = superimpose(
                 gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
@@ -210,22 +211,22 @@ class OpenFoldWrapper(pl.LightningModule):
             metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score
 
         return metrics
 
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # return torch.optim.Adam(
         #     self.model.parameters(),
         #     lr=learning_rate,
         #     eps=eps
         # )
 
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
@@ -250,8 +251,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_load_checkpoint(self, checkpoint):
         ema = checkpoint["ema"]
-        if(not self.model.template_config.enabled):
-            ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
+        if (not self.model.template_config.enabled):
+            ema["params"] = {k: v for k,
+                             v in ema["params"].items() if not "template" in k}
         self.ema.load_state_dict(ema)
 
     def on_save_checkpoint(self, checkpoint):
@@ -262,23 +264,23 @@ class OpenFoldWrapper(pl.LightningModule):
     def load_from_jax(self, jax_path):
         model_basename = os.path.splitext(
             os.path.basename(
                 os.path.normpath(jax_path)
             )
         )[0]
         model_version = "_".join(model_basename.split("_")[1:])
         import_jax_weights_(
             self.model, jax_path, version=model_version
         )
 
 
 def main(args):
-    if(args.seed is not None):
+    if (args.seed is not None):
         seed_everything(args.seed)
 
     config = model_config(
         args.config_preset,
         train=True,
         low_prec=(str(args.precision) == "16")
     )
 
     if args.experiment_config_json:
@@ -321,30 +323,31 @@ def main(args):
     if args.resume_from_jax_params:
         model_module.load_from_jax(args.resume_from_jax_params)
-        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
+        logging.info(
+            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
 
     # TorchScript components of the model
-    if(args.script_modules):
+    if (args.script_modules):
         script_preset_(model_module)
 
     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     else:
         data_module = OpenFoldDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
 
     data_module.prepare_data()
     data_module.setup()
 
     callbacks = []
-    if(args.checkpoint_every_epoch):
+    if (args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
@@ -352,7 +355,7 @@ def main(args):
         )
         callbacks.append(mc)
 
-    if(args.early_stopping):
+    if (args.early_stopping):
         es = EarlyStoppingVerbose(
             monitor="val/lddt_ca",
             min_delta=args.min_delta,
@@ -364,7 +367,7 @@ def main(args):
         )
         callbacks.append(es)
 
-    if(args.log_performance):
+    if (args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
             log_file=os.path.join(args.output_dir, "performance_log.json"),
@@ -372,12 +375,12 @@ def main(args):
         )
         callbacks.append(perf)
 
-    if(args.log_lr):
+    if (args.log_lr):
         lr_monitor = LearningRateMonitor(logging_interval="step")
         callbacks.append(lr_monitor)
 
     loggers = []
-    if(args.wandb):
+    if (args.wandb):
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
@@ -388,38 +391,43 @@ def main(args):
         )
         loggers.append(wdb_logger)
 
-    if(args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+    if (args.deepspeed_config_path is not None):
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
         )
-        if(args.wandb):
+        if (args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False)
     else:
         strategy = None
 
-    if(args.wandb):
+    if (args.wandb):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-    )
+    # Raw dump of all args accepted by the pl.Trainer constructor
+    trainer_kws = set([
+        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger',
+        'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps',
+        'min_steps', 'max_time', 'limit_train_batches', 'limit_val_batches',
+        'limit_test_batches', 'limit_predict_batches', 'overfit_batches',
+        'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps',
+        'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar',
+        'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val',
+        'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode',
+        'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones',
+        'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs',
+        'default_root_dir',
+    ])
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)
 
-    if(args.resume_model_weights_only):
+    if (args.resume_model_weights_only):
         ckpt_path = None
     else:
         ckpt_path = args.resume_from_ckpt
 
     trainer.fit(
         model_module,
         datamodule=data_module,
         ckpt_path=ckpt_path,
     )
@@ -621,36 +629,59 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser = pl.Trainer.add_argparse_args(parser)
-
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
-
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    parser.add_argument(
+        "--num_nodes", type=int, default=1,
+    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+    )
+    parser.add_argument(
+        "--precision", type=str, default=None,
+    )
+    parser.add_argument(
+        "--replace_sampler_ddp", type=bool_type, default=True,
+    )
+    parser.add_argument(
+        "--max_epochs", type=int, default=1,
+    )
+    parser.add_argument(
+        "--log_every_n_steps", type=int, default=25,
+    )
+    parser.add_argument(
+        "--num_sanity_val_steps", type=int, default=0,
+    )
+    # parser = pl.Trainer.add_argparse_args(parser)
+    #
+    # # Disable the initial validation pass
+    # parser.set_defaults(
+    #     num_sanity_val_steps=0,
+    # )
+    #
+    # # Remove some buggy/redundant arguments introduced by the Trainer
+    # remove_arguments(
+    #     parser,
+    #     [
+    #         "--accelerator",
+    #         "--resume_from_checkpoint",
+    #         "--reload_dataloaders_every_epoch",
+    #         "--reload_dataloaders_every_n_epochs",
+    #     ]
+    # )
     args = parser.parse_args()
 
-    if(args.seed is None and
+    if (args.seed is None and
         ((args.gpus is not None and args.gpus > 1) or
          (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
 
-    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
+    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
         raise ValueError("DeepSpeed and FP16 training are not compatible")
 
-    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
-        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
+    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
+        raise ValueError(
+            "Choose between loading pretrained Jax-weights and a checkpoint-path")
 
     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1