Commit 7de0ab00 authored by Jennifer's avatar Jennifer Committed by Jennifer Wei
Browse files

first pass changes to run with pl 2.1

parent a51b08cd
...@@ -73,7 +73,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -73,7 +73,7 @@ class OpenFoldWrapper(pl.LightningModule):
on_step=train, on_epoch=(not train), logger=True, sync_dist=False, on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
) )
if(train): if (train):
self.log( self.log(
f"{phase}/{loss_name}_epoch", f"{phase}/{loss_name}_epoch",
indiv_loss, indiv_loss,
...@@ -87,7 +87,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -87,7 +87,7 @@ class OpenFoldWrapper(pl.LightningModule):
superimposition_metrics=(not train) superimposition_metrics=(not train)
) )
for k,v in other_metrics.items(): for k, v in other_metrics.items():
self.log( self.log(
f"{phase}/{k}", f"{phase}/{k}",
torch.mean(v), torch.mean(v),
...@@ -96,7 +96,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -96,7 +96,7 @@ class OpenFoldWrapper(pl.LightningModule):
) )
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device): if (self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device) self.ema.to(batch["aatype"].device)
ground_truth = batch.pop('gt_features', None) ground_truth = batch.pop('gt_features', None)
...@@ -127,12 +127,13 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -127,12 +127,13 @@ class OpenFoldWrapper(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if (self.cached_weights is None):
# model.state_dict() contains references to model weights rather # model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling # than copies. Therefore, we need to clone them before calling
# load_state_dict(). # load_state_dict().
clone_param = lambda t: t.detach().clone() def clone_param(t): return t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) self.cached_weights = tensor_tree_map(
clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
ground_truth = batch.pop('gt_features', None) ground_truth = batch.pop('gt_features', None)
...@@ -197,7 +198,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -197,7 +198,7 @@ class OpenFoldWrapper(pl.LightningModule):
metrics["drmsd_ca"] = drmsd_ca_score metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics): if (superimposition_metrics):
superimposed_pred, alignment_rmsd = superimpose( superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
) )
...@@ -246,8 +247,9 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -246,8 +247,9 @@ class OpenFoldWrapper(pl.LightningModule):
def on_load_checkpoint(self, checkpoint): def on_load_checkpoint(self, checkpoint):
ema = checkpoint["ema"] ema = checkpoint["ema"]
if(not self.model.template_config.enabled): if (not self.model.template_config.enabled):
ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} ema["params"] = {k: v for k,
v in ema["params"].items() if not "template" in k}
self.ema.load_state_dict(ema) self.ema.load_state_dict(ema)
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
...@@ -331,10 +333,11 @@ def main(args): ...@@ -331,10 +333,11 @@ def main(args):
if args.resume_from_jax_params: if args.resume_from_jax_params:
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") logging.info(
f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model # TorchScript components of the model
if(args.script_modules): if (args.script_modules):
script_preset_(model_module) script_preset_(model_module)
if "multimer" in args.config_preset: if "multimer" in args.config_preset:
...@@ -354,7 +357,7 @@ def main(args): ...@@ -354,7 +357,7 @@ def main(args):
data_module.setup() data_module.setup()
callbacks = [] callbacks = []
if(args.checkpoint_every_epoch): if (args.checkpoint_every_epoch):
mc = ModelCheckpoint( mc = ModelCheckpoint(
every_n_epochs=1, every_n_epochs=1,
auto_insert_metric_name=False, auto_insert_metric_name=False,
...@@ -362,7 +365,7 @@ def main(args): ...@@ -362,7 +365,7 @@ def main(args):
) )
callbacks.append(mc) callbacks.append(mc)
if(args.early_stopping): if (args.early_stopping):
es = EarlyStoppingVerbose( es = EarlyStoppingVerbose(
monitor="val/lddt_ca", monitor="val/lddt_ca",
min_delta=args.min_delta, min_delta=args.min_delta,
...@@ -374,7 +377,7 @@ def main(args): ...@@ -374,7 +377,7 @@ def main(args):
) )
callbacks.append(es) callbacks.append(es)
if(args.log_performance): if (args.log_performance):
global_batch_size = args.num_nodes * args.gpus global_batch_size = args.num_nodes * args.gpus
perf = PerformanceLoggingCallback( perf = PerformanceLoggingCallback(
log_file=os.path.join(args.output_dir, "performance_log.json"), log_file=os.path.join(args.output_dir, "performance_log.json"),
...@@ -382,7 +385,7 @@ def main(args): ...@@ -382,7 +385,7 @@ def main(args):
) )
callbacks.append(perf) callbacks.append(perf)
if(args.log_lr): if (args.log_lr):
lr_monitor = LearningRateMonitor(logging_interval="step") lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks.append(lr_monitor) callbacks.append(lr_monitor)
...@@ -686,16 +689,17 @@ if __name__ == "__main__": ...@@ -686,16 +689,17 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if(args.seed is None and if (args.seed is None and
((args.gpus is not None and args.gpus > 1) or ((args.gpus is not None and args.gpus > 1) or
(args.num_nodes is not None and args.num_nodes > 1))): (args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified") raise ValueError("For distributed training, --seed must be specified")
if(str(args.precision) == "16" and args.deepspeed_config_path is not None): if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible") raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path")
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment