Unverified Commit cfd0fc6e authored by Gustaf Ahdritz, committed by GitHub

Merge pull request #76 from aqlaboratory/chunking_experiment_rebased

parents c9e0f894 2726892a
@@ -4,19 +4,19 @@ from datetime import date
 def add_data_args(parser: argparse.ArgumentParser):
     parser.add_argument(
-        'uniref90_database_path', type=str,
+        '--uniref90_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'mgnify_database_path', type=str,
+        '--mgnify_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'pdb70_database_path', type=str,
+        '--pdb70_database_path', type=str, default=None,
     )
     parser.add_argument(
-        'template_mmcif_dir', type=str,
+        '--template_mmcif_dir', type=str, default=None,
     )
     parser.add_argument(
-        'uniclust30_database_path', type=str,
+        '--uniclust30_database_path', type=str, default=None,
     )
     parser.add_argument(
         '--bfd_database_path', type=str, default=None,
...
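Note: this hunk converts the five genetic-database paths from required positional arguments into optional --flags with default=None, so a run can omit databases it does not have. A minimal sketch of the pattern this enables downstream (illustrative only, not code from this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--uniref90_database_path', type=str, default=None)
    args = parser.parse_args(['--uniref90_database_path', '/data/uniref90.fasta'])

    # Callers can now branch on None and skip searches for missing databases
    if args.uniref90_database_path is not None:
        print('searching', args.uniref90_database_path)
    else:
        print('skipping the UniRef90 search')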
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
         out_repro_msa = out_repro_msa.cpu()
         out_repro_pair = out_repro_pair.cpu()
-        assert torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps)
-        assert torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps)
+        assert(torch.max(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
+        assert(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
 class TestExtraMSAStack(unittest.TestCase):
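Note: the rewritten asserts here (and the matching assertTrue fixes later in this commit) correct an operator-precedence bug. In the old form the comparison sat inside torch.max, so the max ran over a boolean tensor and was truthy whenever any element was within tolerance; the new form compares the worst-case error against consts.eps. A small demonstration:

    import torch

    d = torch.tensor([0.0, 5.0])  # second element is far outside tolerance
    eps = 1e-4

    # Buggy: max over booleans is True because d[0] alone passes
    print(torch.max(torch.abs(d) < eps))   # tensor(True)
    # Fixed: the largest error (5.0) is compared against eps
    print(torch.max(torch.abs(d)) < eps)   # tensor(False)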
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
             transition_n,
             msa_dropout,
             pair_stack_dropout,
-            blocks_per_ckpt=None,
+            ckpt=False,
             inf=inf,
             eps=eps,
         ).eval()
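Note: the test now passes a boolean ckpt flag instead of blocks_per_ckpt=None, suggesting the ExtraMSAStack's gradient-checkpointing control was simplified to an on/off switch. A rough sketch of what such a flag typically gates (run_block is hypothetical, not OpenFold's API):

    import torch
    from torch.utils.checkpoint import checkpoint

    def run_block(block, x, ckpt=False):
        # With ckpt=True, activations inside `block` are recomputed during
        # the backward pass instead of stored, trading compute for memory
        if ckpt and torch.is_grad_enabled():
            return checkpoint(block, x)
        return block(x)

    layer = torch.nn.Linear(8, 8)
    out = run_block(layer, torch.randn(2, 8, requires_grad=True), ckpt=True)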
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
         out_gt = torch.as_tensor(np.array(out_gt))
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_transition(
+            model.evoformer.blocks[0].core.msa_transition(
                 torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                 mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
             )
             .cpu()
         )
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        print(out_gt)
+        print(out_repro)
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
 if __name__ == "__main__":
...
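Note: several tests in this commit (here and in the triangular-attention and triangular-multiplication tests below) now reach submodules through an extra .core attribute, indicating that Evoformer blocks nest these layers inside a core child module after the chunking refactor; the two print calls are debug output committed alongside the corrected assertion. For reference, a stock near-equivalent of the fixed tolerance check:

    import torch

    out_gt = torch.zeros(4)
    out_repro = torch.full((4,), 1e-9)
    eps = 1e-6

    # Essentially torch.max(torch.abs(out_gt - out_repro)) < eps
    assert torch.allclose(out_repro, out_gt, rtol=0.0, atol=eps)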
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
         out_repro = out_repro["sm"]["positions"][-1]
         out_repro = out_repro.squeeze(0)
+        print(torch.mean(torch.abs(out_gt - out_repro)))
         print(torch.max(torch.abs(out_gt - out_repro)))
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_att_row(
+            model.evoformer.blocks[0].msa_att_row(
                 torch.as_tensor(msa_act).cuda(),
                 z=torch.as_tensor(pair_act).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask).cuda(),
             )
-            .cpu()
-        )
+        ).cpu()
         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0]
-            .msa_att_col(
+            model.evoformer.blocks[0].msa_att_col(
                 torch.as_tensor(msa_act).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask).cuda(),
             )
-            .cpu()
-        )
+        ).cpu()
+        print(torch.mean(torch.abs(out_gt - out_repro)))
         self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.extra_msa_stack.stack.blocks[0]
-            .msa_att_col(
+            model.extra_msa_stack.blocks[0].msa_att_col(
                 torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                 chunk_size=4,
                 mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_att_start
+            model.evoformer.blocks[0].core.tri_att_start
             if starting
-            else model.evoformer.blocks[0].tri_att_end
+            else model.evoformer.blocks[0].core.tri_att_end
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
             chunk_size=None,
         ).cpu()
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].tri_mul_in
+            model.evoformer.blocks[0].core.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].tri_mul_out
+            else model.evoformer.blocks[0].core.tri_mul_out
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
...
@@ -13,6 +13,7 @@ import time
 import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 import torch
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.argparse import remove_arguments
-from openfold.utils.loss import AlphaFoldLoss
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.seed import seed_everything
 from openfold.utils.tensor_utils import tensor_tree_map
 from scripts.zero_to_fp32 import (
@@ -66,21 +67,28 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)
-        self.log("loss", loss)
-        return {"loss": loss}
+        self.log("train/loss", loss, on_step=True, logger=True)
+        return loss
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])
         # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)
-        loss = self.loss(outputs, batch)
-        self.log("val_loss", loss, prog_bar=True)
-        return {"val_loss": loss}
+        loss = lddt_ca(
+            outputs["final_atom_positions"],
+            batch["all_atom_positions"],
+            batch["all_atom_mask"],
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+        self.log("val/loss", loss, logger=True)
     def validation_epoch_end(self, _):
         # Restore the model weights to normal
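Note: validation_step now logs a structure-quality metric instead of the training loss: lddt_ca scores agreement between predicted and ground-truth Cα positions. lDDT is a superposition-free score in [0, 1] built from preserved inter-residue distances, and higher is better, which is why the checkpoint callback below switches to mode="max". A simplified, unbatched sketch of the idea (not OpenFold's lddt_ca implementation):

    import torch

    def lddt_sketch(pred, gt, mask, cutoff=15.0, eps=1e-10):
        # pred, gt: [N, 3] coordinates; mask: [N] validity flags (0./1.)
        d_gt = torch.cdist(gt, gt)
        d_pred = torch.cdist(pred, pred)
        # Score pairs within `cutoff` in the ground truth, excluding self-pairs
        pair_mask = (d_gt < cutoff).float() * mask[:, None] * mask[None, :]
        pair_mask = pair_mask * (1.0 - torch.eye(gt.shape[0]))
        delta = torch.abs(d_gt - d_pred)
        # Fraction of 0.5/1/2/4 A thresholds at which the distance is preserved
        score = sum((delta < t).float() for t in (0.5, 1.0, 2.0, 4.0)) / 4.0
        return (pair_mask * score).sum() / (pair_mask.sum() + eps)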
@@ -101,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)
+    def on_load_checkpoint(self, checkpoint):
+        self.ema.load_state_dict(checkpoint["ema"])
     def on_save_checkpoint(self, checkpoint):
         checkpoint["ema"] = self.ema.state_dict()
@@ -140,15 +151,15 @@ def main(args):
     if(args.checkpoint_best_val):
         checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
         mc = ModelCheckpoint(
+            dirpath=checkpoint_dir,
             filename="openfold_{epoch}_{step}_{val_loss:.2f}",
-            monitor="val_loss",
+            monitor="val/loss",
+            mode="max",
         )
         callbacks.append(mc)
     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val_loss",
+            monitor="val/loss",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
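Note: both callbacks now monitor the renamed "val/loss" key, and mode="max" matches the new metric (lDDT-Cα, higher is better) despite the key still being named like a loss. One caveat: the filename template still interpolates {val_loss}, which no longer corresponds to any logged metric. A self-consistent variant might look like this (hypothetical, not from this diff):

    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    mc = ModelCheckpoint(
        dirpath="checkpoints",
        filename="openfold_{epoch}_{step}",  # interpolate only logged keys
        monitor="val/loss",
        mode="max",
    )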
@@ -157,7 +168,7 @@ def main(args):
             strict=True,
         )
         callbacks.append(es)
     if(args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
@@ -166,24 +177,41 @@ def main(args):
         )
         callbacks.append(perf)
+    loggers = []
+    if(args.wandb):
+        wdb_logger = WandbLogger(
+            name=args.experiment_name,
+            save_dir=args.output_dir,
+            id=args.wandb_id,
+            project=args.wandb_project,
+            **{"entity": args.wandb_entity}
+        )
+        loggers.append(wdb_logger)
     if(args.deepspeed_config_path is not None):
-        if "SLURM_JOB_ID" in os.environ:
-            cluster_environment = SLURMEnvironment()
-        else:
-            cluster_environment = None
+        #if "SLURM_JOB_ID" in os.environ:
+        #    cluster_environment = SLURMEnvironment()
+        #else:
+        #    cluster_environment = None
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
-            cluster_environment=cluster_environment,
+            # cluster_environment=cluster_environment,
         )
-    elif (args.gpus is not None and args.gpus) > 1 or args.num_nodes > 1:
+        if(args.wandb):
+            wdb_logger.experiment.save(args.deepspeed_config_path)
+            wdb_logger.experiment.save("openfold/config.py")
+    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None
     trainer = pl.Trainer.from_argparse_args(
         args,
+        default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
+        logger=loggers,
     )
     if(args.resume_model_weights_only):
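Note: Weights & Biases logging is opt-in via --wandb and handed to the Trainer through logger=loggers; when DeepSpeed is active, the run also uploads the DeepSpeed config and openfold/config.py so the experiment's configuration is archived alongside its metrics. The same hunk also fixes a misplaced parenthesis in the multi-GPU condition (the old code compared a boolean to 1). A usage sketch with the flags added at the bottom of this script (names are placeholders, other required arguments elided):

    python3 train_openfold.py <data arguments...> \
        --wandb \
        --experiment_name my_run \
        --wandb_project openfold \
        --wandb_entity my_team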
@@ -198,7 +226,7 @@ def main(args):
         )
     trainer.save_checkpoint(
-        os.path.join(trainer.logger.log_dir, "checkpoints", "final.ckpt")
+        os.path.join(args.output_dir, "checkpoints", "final.ckpt")
     )
@@ -318,10 +346,37 @@ if __name__ == "__main__":
         "--log_performance", type=bool_type, default=False,
         help="Measure performance"
     )
+    parser.add_argument(
+        "--wandb", action="store_true", default=False,
+    )
+    parser.add_argument(
+        "--experiment_name", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_id", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_project", type=str, default=None,
+    )
+    parser.add_argument(
+        "--wandb_entity", type=str, default=None,
+    )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
         help="Whether to TorchScript eligible components of them model"
     )
+    parser.add_argument(
+        "--train_prot_data_cache_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--distillation_prot_data_cache_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--train_epoch_len", type=int, default=10000,
+    )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
+    )
     parser = pl.Trainer.add_argparse_args(parser)
     # Disable the initial validation pass
@@ -330,7 +385,15 @@ if __name__ == "__main__":
     )
     # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(parser, ["--accelerator", "--resume_from_checkpoint"])
+    remove_arguments(
+        parser,
+        [
+            "--accelerator",
+            "--resume_from_checkpoint",
+            "--reload_dataloaders_every_epoch",
+            "--reload_dataloaders_every_n_epochs",
+        ]
+    )
     args = parser.parse_args()
@@ -339,4 +402,7 @@ if __name__ == "__main__":
        (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
+    # This re-applies the training-time filters at the beginning of every epoch
+    args.reload_dataloaders_every_n_epochs = 1
     main(args)
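Note: rather than exposing it as a CLI flag (the two --reload_dataloaders_* arguments are stripped above), the script pins reload_dataloaders_every_n_epochs to 1 so Lightning rebuilds the training dataloader at every epoch boundary and the training-time filters are redrawn each epoch. Assuming a Lightning version that supports this Trainer argument, the effect is the same as:

    import pytorch_lightning as pl

    # Rebuild training dataloaders at the start of every epoch
    trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)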