import argparse
import logging
import os
import random
import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch

from openfold.config import model_config
from openfold.data.data_modules import (
    OpenFoldDataModule,
    DummyDataLoader,
)
from openfold.model.model import AlphaFold
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.argparse import remove_arguments
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.seed import seed_everything
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.logger import PerformanceLoggingCallback
from scripts.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint
)


class OpenFoldWrapper(pl.LightningModule):
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        # Holds the raw (non-EMA) weights while validation runs with the EMA
        # weights loaded into self.model
        self.cached_weights = None

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        # Compute loss
        loss = self.loss(outputs, batch)

        #if(torch.isnan(loss) or torch.isinf(loss)):
        #    logging.warning("loss is NaN. Skipping example...")
        #    loss = loss.new_tensor(0., requires_grad=True)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        # At the start of validation, stash the raw weights and load the EMA
        # weights in their place
        if(self.cached_weights is None):
            self.cached_weights = self.model.state_dict()
            self.model.load_state_dict(self.ema.state_dict()["params"])

        # Calculate validation loss
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        loss = self.loss(outputs, batch)
        return {"val_loss": loss}

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-8,
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        return torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps,
        )

    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()
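
# A minimal sketch of the EMA weight-swapping pattern used by the wrapper
# above, with torch.nn.Linear standing in for AlphaFold (the decay value is
# illustrative, not the one used in training):
#
#   toy = torch.nn.Linear(4, 4)
#   ema = ExponentialMovingAverage(model=toy, decay=0.999)
#   ema.update(toy)                                  # after each optimizer step
#   cached = toy.state_dict()                        # stash the raw weights
#   toy.load_state_dict(ema.state_dict()["params"])  # validate with EMA weights
#   toy.load_state_dict(cached)                      # then restore the originals
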
def main(args):
    if(args.seed is not None):
        seed_everything(args.seed)

    config = model_config(
        "model_1",
        train=True,
        low_prec=(args.precision == 16),
    )
    model_module = OpenFoldWrapper(config)
    if(args.resume_from_ckpt and args.resume_model_weights_only):
        sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
        # Strip the "module." prefix prepended by the distributed wrapper
        sd = {k[len("module."):]: v for k, v in sd.items()}
        model_module.load_state_dict(sd)
        logging.info("Successfully loaded model weights...")

    #data_module = DummyDataLoader("batch.pickle")
    data_module = OpenFoldDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args),
    )
    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_best_val):
        checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
        mc = ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="openfold_{epoch}_{step}_{val_loss:.2f}",
            monitor="val_loss",
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val_loss",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="min",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_dir=args.output_dir,
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if(args.deepspeed_config_path is not None):
        strategy = DeepSpeedPlugin(config=args.deepspeed_config_path)
    elif((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1)):
        strategy = "ddp"
    else:
        strategy = None

    trainer = pl.Trainer.from_argparse_args(
        args,
        strategy=strategy,
        callbacks=callbacks,
    )

    if(args.resume_model_weights_only):
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )

    trainer.save_checkpoint(
        os.path.join(trainer.logger.log_dir, "checkpoints", "final.ckpt")
    )
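
# A minimal sketch of a DeepSpeed config for --deepspeed_config_path (the
# values are illustrative assumptions; prefer the config shipped with the
# repository if one is provided):
#
#   {
#       "train_micro_batch_size_per_gpu": 1,
#       "gradient_accumulation_steps": 1,
#       "optimizer": {
#           "type": "Adam",
#           "params": {"lr": 1e-3, "eps": 1e-8}
#       },
#       "zero_optimization": {"stage": 2},
#       "fp16": {"enabled": true}
#   }
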
def bool_type(bool_str: str):
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so boolean flags are parsed explicitly instead
    bool_str_lower = bool_str.lower()
    if(bool_str_lower in ("false", "f", "no", "n", "0")):
        return False
    elif(bool_str_lower in ("true", "t", "yes", "y", "1")):
        return True
    else:
        raise ValueError(f"Cannot interpret {bool_str} as a bool")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
                if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_mapping_path", type=str, default=None,
        help='''Optional path to a .json file containing a mapping from
                consecutive numerical indices to sample names. Used to filter
                the training set'''
    )
    parser.add_argument(
        "--distillation_mapping_path", type=str, default=None,
        help="""See --train_mapping_path"""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
                files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_best_val", type=bool_type, default=True,
        help="""Whether to save the model parameters that perform best during
                validation"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an
                improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--log_performance", action='store_true',
        help="Measure performance"
    )
    parser = pl.Trainer.add_argparse_args(parser)

    # Disable the initial validation pass
    parser.set_defaults(
        num_sanity_val_steps=0,
    )

    # Remove some buggy/redundant arguments introduced by the Trainer
    remove_arguments(parser, ["--accelerator", "--resume_from_checkpoint"])

    args = parser.parse_args()

    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    main(args)
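
# Example invocation (a sketch; all paths and the cutoff date below are
# placeholder assumptions, not values from the repository):
#
#   python3 train_openfold.py \
#       data/mmcifs/ data/alignments/ data/template_mmcifs/ out/ 2021-09-30 \
#       --val_data_dir data/val_mmcifs/ \
#       --val_alignment_dir data/val_alignments/ \
#       --gpus 2 --num_nodes 1 \
#       --seed 42 \
#       --deepspeed_config_path deepspeed_config.json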