Commit df4dfacb authored by Jennifer

First pass of changes to run with PyTorch Lightning 2.1

parent e813bb53
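For orientation, here is a minimal sketch of the Lightning 2.x API surface these changes target (illustrative only, based on the public PyTorch Lightning 2.x interface; not part of the commit):

import pytorch_lightning as pl
from pytorch_lightning import seed_everything  # moved out of pytorch_lightning.utilities.seed
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy  # replace the old *Plugin classes

class ExampleDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # 2.x calls setup() with a stage argument ("fit", "validate", "test", or "predict")
        ...

class ExampleModule(pl.LightningModule):
    def on_validation_epoch_end(self):
        # replaces validation_epoch_end(self, outputs), which was removed in 2.0
        ...

# pl.Trainer.add_argparse_args / from_argparse_args were also removed in 2.0,
# which is why the Trainer arguments are assembled manually further down.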
......@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self):
def setup(self, stage=None):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
template_mmcif_dir=self.template_mmcif_dir,
......@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
mode="predict",
)
def _gen_dataloader(self, stage):
def _gen_dataloader(self, stage=None):
generator = None
if self.batch_seed is not None:
generator = torch.Generator()
......@@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
def val_dataloader(self):
if self.eval_dataset is not None:
return self._gen_dataloader("eval")
return None
# Temp fix to pass the validation step
return []
def predict_dataloader(self):
return self._gen_dataloader("predict")
......@@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
self.training_mode = self.train_data_dir is not None
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
def setup(self):
def setup(self, stage=None):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleMultimerDataset,
template_mmcif_dir=self.template_mmcif_dir,
......
......@@ -2,7 +2,7 @@ import os
import logging
import random
import numpy as np
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import seed_everything
from openfold.utils.suppress_output import SuppressLogging
......
......@@ -6,6 +6,7 @@ import sys
import unittest
import numpy as np
import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold
......@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function"
)
return params
def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff = torch.abs(expected - actual)
err = compare_func(abs_diff)
zero_tensor = torch.tensor(0, dtype=err.dtype)
rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
def assert_max_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.max, expected, actual, eps)
def assert_mean_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.mean, expected, actual, eps)
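A quick usage sketch for the two new helpers (the tensors below are made up; the rtol values above match the torch.testing.assert_close defaults for bfloat16 and float32):

import torch

expected = torch.randn(8, 8)
actual = expected + 1e-7 * torch.randn(8, 8)

# passes when the largest absolute difference is within eps
assert_max_abs_diff_small(expected, actual, eps=1e-5)
# passes when the mean absolute difference is within eps
assert_mean_abs_diff_small(expected, actual, eps=1e-5)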
......@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error {err}')
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
def test_compare_model(self):
"""
......@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.mean(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
if __name__ == "__main__":
......
......@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
......@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase):
......@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase):
......@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__":
......
......@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase):
......@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase):
......
......@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False,
).cpu()
diff = torch.max(torch.abs(out_gt - out_repro))
self.assertTrue(diff < consts.eps,
msg=f"Found difference between ground truth and reproduction of {diff}")
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase):
......@@ -286,7 +284,7 @@ class Template(unittest.TestCase):
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
......@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
......
......@@ -7,7 +7,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
import torch
from openfold.config import model_config
......@@ -71,7 +71,7 @@ class OpenFoldWrapper(pl.LightningModule):
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
if (train):
self.log(
f"{phase}/{loss_name}_epoch",
indiv_loss,
......@@ -85,7 +85,7 @@ class OpenFoldWrapper(pl.LightningModule):
superimposition_metrics=(not train)
)
for k,v in other_metrics.items():
for k, v in other_metrics.items():
self.log(
f"{phase}/{k}",
torch.mean(v),
......@@ -93,7 +93,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device):
if (self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
ground_truth = batch.pop('gt_features', None)
......@@ -124,12 +124,13 @@ class OpenFoldWrapper(pl.LightningModule):
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
if (self.cached_weights is None):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
def clone_param(t): return t.detach().clone()
self.cached_weights = tensor_tree_map(
clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
ground_truth = batch.pop('gt_features', None)
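As a side note on the clone_param change above, a minimal illustration (using a throwaway nn.Linear, not part of the diff) of why the state_dict values must be cloned before load_state_dict overwrites them in place:

import torch

lin = torch.nn.Linear(2, 2)
ref = lin.state_dict()["weight"]   # shares storage with the live parameter
snap = ref.detach().clone()        # independent copy
lin.load_state_dict({"weight": torch.zeros(2, 2), "bias": torch.zeros(2)})
assert torch.equal(ref, torch.zeros(2, 2))       # the reference now reflects the loaded weights
assert not torch.equal(snap, torch.zeros(2, 2))  # the clone still holds the original values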
......@@ -152,7 +153,7 @@ class OpenFoldWrapper(pl.LightningModule):
self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _):
def on_validation_epoch_end(self):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
self.cached_weights = None
......@@ -194,7 +195,7 @@ class OpenFoldWrapper(pl.LightningModule):
metrics["drmsd_ca"] = drmsd_ca_score
if(superimposition_metrics):
if (superimposition_metrics):
superimposed_pred, alignment_rmsd = superimpose(
gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
)
......@@ -215,11 +216,11 @@ class OpenFoldWrapper(pl.LightningModule):
learning_rate: float = 1e-3,
eps: float = 1e-5,
) -> torch.optim.Adam:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam(
self.model.parameters(),
......@@ -247,8 +248,9 @@ class OpenFoldWrapper(pl.LightningModule):
def on_load_checkpoint(self, checkpoint):
ema = checkpoint["ema"]
if(not self.model.template_config.enabled):
ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
if (not self.model.template_config.enabled):
ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
self.ema.load_state_dict(ema)
def on_save_checkpoint(self, checkpoint):
......@@ -270,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
def main(args):
if(args.seed is not None):
if (args.seed is not None):
seed_everything(args.seed)
config = model_config(
......@@ -280,28 +282,31 @@ def main(args):
)
model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
if (args.resume_from_ckpt):
if (os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
last_global_step = int(sd['global_step'])
model_module.resume_last_lr_step(last_global_step)
logging.info("Successfully loaded last lr step...")
if(args.resume_from_ckpt and args.resume_model_weights_only):
if(os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
if (args.resume_from_ckpt and args.resume_model_weights_only):
if (os.path.isdir(args.resume_from_ckpt)):
sd = get_fp32_state_dict_from_zero_checkpoint(
args.resume_from_ckpt)
else:
sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()}
sd = {k[len("module."):]: v for k, v in sd.items()}
import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params):
if (args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params)
logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
logging.info(
f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
# TorchScript components of the model
if(args.script_modules):
if (args.script_modules):
script_preset_(model_module)
if "multimer" in args.config_preset:
......@@ -321,7 +326,7 @@ def main(args):
data_module.setup()
callbacks = []
if(args.checkpoint_every_epoch):
if (args.checkpoint_every_epoch):
mc = ModelCheckpoint(
every_n_epochs=1,
auto_insert_metric_name=False,
......@@ -329,7 +334,7 @@ def main(args):
)
callbacks.append(mc)
if(args.early_stopping):
if (args.early_stopping):
es = EarlyStoppingVerbose(
monitor="val/lddt_ca",
min_delta=args.min_delta,
......@@ -341,7 +346,7 @@ def main(args):
)
callbacks.append(es)
if(args.log_performance):
if (args.log_performance):
global_batch_size = args.num_nodes * args.gpus
perf = PerformanceLoggingCallback(
log_file=os.path.join(args.output_dir, "performance_log.json"),
......@@ -349,12 +354,12 @@ def main(args):
)
callbacks.append(perf)
if(args.log_lr):
if (args.log_lr):
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks.append(lr_monitor)
loggers = []
if(args.wandb):
if (args.wandb):
wdb_logger = WandbLogger(
name=args.experiment_name,
save_dir=args.output_dir,
......@@ -364,32 +369,37 @@ def main(args):
)
loggers.append(wdb_logger)
if(args.deepspeed_config_path is not None):
strategy = DeepSpeedPlugin(
if (args.deepspeed_config_path is not None):
strategy = DeepSpeedStrategy(
config=args.deepspeed_config_path,
)
if(args.wandb):
if (args.wandb):
wdb_logger.experiment.save(args.deepspeed_config_path)
wdb_logger.experiment.save("openfold/config.py")
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
strategy = DDPPlugin(find_unused_parameters=False)
strategy = DDPStrategy(find_unused_parameters=False)
else:
strategy = None
if(args.wandb):
if (args.wandb):
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer.from_argparse_args(
args,
default_root_dir=args.output_dir,
strategy=strategy,
callbacks=callbacks,
logger=loggers,
)
if(args.resume_model_weights_only):
# Raw dump of all args from pl.Trainer constructor
trainer_kws = {
'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger',
'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps',
'min_steps', 'max_time', 'limit_train_batches', 'limit_val_batches',
'limit_test_batches', 'limit_predict_batches', 'overfit_batches',
'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps',
'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar',
'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val',
'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode',
'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones',
'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs',
'default_root_dir',
}
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
trainer_args.update({
'default_root_dir': args.output_dir,
'strategy': strategy,
'callbacks': callbacks,
'logger': loggers,
})
trainer = pl.Trainer(**trainer_args)
if (args.resume_model_weights_only):
ckpt_path = None
else:
ckpt_path = args.resume_from_ckpt
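Because trainer_kws is a hand-maintained copy of the pl.Trainer constructor signature, one hedged way to sanity-check it at startup (illustrative only) is to compare it against the signature via inspect:

import inspect

valid_kws = set(inspect.signature(pl.Trainer.__init__).parameters) - {"self"}
unknown = trainer_kws - valid_kws
if unknown:
    raise ValueError(f"Not valid pl.Trainer arguments: {unknown}")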
......@@ -594,36 +604,59 @@ if __name__ == "__main__":
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
)
parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
parser.set_defaults(
num_sanity_val_steps=0,
parser.add_argument(
"--num_nodes", type=int, default=1,
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments(
parser,
[
"--accelerator",
"--resume_from_checkpoint",
"--reload_dataloaders_every_epoch",
"--reload_dataloaders_every_n_epochs",
]
parser.add_argument(
"--gpus", type=int, default=1,
)
parser.add_argument(
"--precision", type=str, default=None,
)
parser.add_argument(
"--replace_sampler_ddp", type=bool_type, default=True,
)
parser.add_argument(
"--max_epochs", type=int, default=1,
)
parser.add_argument(
"--log_every_n_steps", type=int, default=25,
)
parser.add_argument(
"--num_sanity_val_steps", type=int, default=0,
)
# parser = pl.Trainer.add_argparse_args(parser)
#
# # Disable the initial validation pass
# parser.set_defaults(
# num_sanity_val_steps=0,
# )
# # Remove some buggy/redundant arguments introduced by the Trainer
# remove_arguments(
# parser,
# [
# "--accelerator",
# "--resume_from_checkpoint",
# "--reload_dataloaders_every_epoch",
# "--reload_dataloaders_every_n_epochs",
# ]
# )
args = parser.parse_args()
if(args.seed is None and
if (args.seed is None and
((args.gpus is not None and args.gpus > 1) or
(args.num_nodes is not None and args.num_nodes > 1))):
raise ValueError("For distributed training, --seed must be specified")
if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
raise ValueError("DeepSpeed and FP16 training are not compatible")
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
......