Commit bbf989a7 authored by Jennifer

psivant local pl2 upgrades without mpi configuration

parents d8418293 ce000c60
-name: openfold-venv
+name: pytorch1-plupgrade
 channels:
   - conda-forge
   - bioconda
@@ -11,7 +11,7 @@ dependencies:
   - openmm=7.7
   - pdbfixer
   - cudatoolkit==11.3.*
-  - pytorch-lightning==1.5.10
+  - pytorch-lightning==2.0.9
   - biopython==1.79
   - numpy==1.21
   - pandas==2.0
@@ -19,11 +19,12 @@ dependencies:
   - requests
   - scipy==1.7
   - tqdm==4.62.2
-  - typing-extensions==3.10
+  - typing-extensions==4.0
   - wandb==0.12.21
   - modelcif==0.7
   - awscli
   - ml-collections
+  - mkl==2024.0.0
   - aria2
   - git
   - bioconda::hmmer==3.3.2
...
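Several pins move at once here (pytorch-lightning 1.5.10 → 2.0.9, typing-extensions 3.10 → 4.0, plus a new mkl pin), so a quick runtime assertion can confirm the rebuilt environment is the one actually being imported. A minimal sketch, not part of the commit:

```python
# Post-install sanity check that the upgraded env is in effect.
# Illustrative only; the version prefix mirrors the pin above.
import pytorch_lightning as pl

assert pl.__version__.startswith("2.0"), f"expected 2.0.x, got {pl.__version__}"
```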
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             with open(distillation_alignment_index_path, "r") as fp:
                 self.distillation_alignment_index = json.load(fp)
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )
 
-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()
@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        return []
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")
@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
...
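These signature changes track the Lightning 2.x DataModule contract: the Trainer now invokes `setup` with a `stage` keyword, and a dataloader hook with nothing to serve returns an empty iterable instead of `None`. A minimal sketch of that contract, assuming pytorch_lightning>=2.0 (`TinyDataModule` is illustrative, not from the repo):

```python
import pytorch_lightning as pl


class TinyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # Lightning 2.x calls this hook as setup(stage="fit" | "validate" |
        # "test" | "predict"); an override without the parameter raises a
        # TypeError once the Trainer passes stage as a keyword.
        print(f"setup ran for stage={stage}")

    def val_dataloader(self):
        # Mirroring the diff above: an empty list signals "no validation
        # batches" to the 2.x loops, where the old code returned None.
        return []


TinyDataModule().setup(stage="fit")
```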
@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
     features["num_alignments"] = np.array(
         [num_alignments] * num_res, dtype=np.int32
     )
-    features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
+    features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
     return features
@@ -590,7 +590,7 @@ def convert_monomer_features(
 ) -> FeatureDict:
     """Reshapes and modifies monomer features for multimer models."""
     converted = {}
-    converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
+    converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
     unnecessary_leading_dim_feats = {
         'sequence', 'domain_name', 'num_alignments', 'seq_length'
     }
@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
         )
         mmcif_feats["release_date"] = np.array(
-            [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
+            [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
         )
         mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
...
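The dtype edits are a spelling change only: `np.object_` is still a valid scalar type, but the builtin `object` sidesteps any confusion with the `np.object` alias that NumPy deprecated in 1.20 and removed in 1.24, and both spellings build an identical array. A minimal sketch (the `species_ids` values are illustrative):

```python
import numpy as np

species_ids = [b"homo_sapiens", b"", b"mus_musculus"]  # illustrative values
via_alias = np.array(species_ids, dtype=np.object_)  # old spelling, still legal
via_builtin = np.array(species_ids, dtype=object)    # spelling used in the diff

# Both spellings yield the identical object ("O") dtype.
assert via_alias.dtype == via_builtin.dtype == np.dtype("O")
```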
@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
 
 def _superimpose_single(reference, coords):
-    reference_np = reference.detach().cpu().numpy()
-    coords_np = coords.detach().cpu().numpy()
+    reference_np = reference.detach().to(torch.float).cpu().numpy()
+    coords_np = coords.detach().to(torch.float).cpu().numpy()
     superimposed, rmsd = _superimpose_np(reference_np, coords_np)
     return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
...
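The added `.to(torch.float)` guards the NumPy hand-off when training runs in reduced precision: NumPy has no bfloat16 dtype, so converting such a tensor directly raises. A minimal sketch of the failure mode the cast avoids:

```python
import torch

coords = torch.randn(8, 3, dtype=torch.bfloat16)  # e.g. from a bf16 training step

try:
    coords.detach().cpu().numpy()  # NumPy cannot represent bfloat16
except TypeError as err:
    print(f"direct conversion fails: {err}")

# Upcasting first, as the diff does, makes the hand-off safe.
coords_np = coords.detach().to(torch.float).cpu().numpy()
print(coords_np.dtype)  # float32
```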
@@ -9,8 +9,11 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks import DeviceStatsMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
+from pytorch_lightning.plugins.environments import MPIEnvironment
+from pytorch_lightning import seed_everything
 import torch
+import wandb
 
 from openfold.config import model_config
 from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
@@ -25,7 +28,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
-from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.utils.validation_metrics import (
@@ -60,7 +62,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.cached_weights = None
         self.last_lr_step = -1
 
-        self.save_hyperparameters
+        self.save_hyperparameters()
 
     def forward(self, batch):
         return self.model(batch)
@@ -71,14 +73,17 @@ class OpenFoldWrapper(pl.LightningModule):
         self.log(
             f"{phase}/{loss_name}",
             indiv_loss,
-            on_step=train, on_epoch=(not train), logger=True,
+            prog_bar=(loss_name == 'loss'),
+            # on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
+            on_step=train, on_epoch=(not train), logger=True, sync_dist=True,
         )
 
         if(train):
             self.log(
                 f"{phase}/{loss_name}_epoch",
                 indiv_loss,
-                on_step=False, on_epoch=True, logger=True,
+                # on_step=False, on_epoch=True, logger=True, sync_dist=False,
+                on_step=False, on_epoch=True, logger=True, sync_dist=True,
             )
 
         with torch.no_grad():
@@ -92,7 +97,9 @@ class OpenFoldWrapper(pl.LightningModule):
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
-                on_step=False, on_epoch=True, logger=True
+                prog_bar=(k == 'loss'),
+                # on_step=False, on_epoch=True, logger=True, sync_dist=False,
+                on_step=False, on_epoch=True, logger=True, sync_dist=True,
             )
 
     def training_step(self, batch, batch_idx):
@@ -155,7 +162,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self._log(loss_breakdown, batch, outputs, train=False)
 
-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
@@ -378,41 +385,58 @@ def main(args):
     callbacks.append(lr_monitor)
 
     loggers = []
+    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
     if(args.wandb):
+        if args.mpi_plugin and is_rank_zero:
+            wandb_init_dict = dict(
+                name=args.experiment_name,
+                project=args.wandb_project,
+                id=args.wandb_id,
+                dir=args.output_dir,
+                resume="allow",
+                anonymous=None,
+                entity=args.wandb_entity
+            )
+            wandb.run = wandb.init(**wandb_init_dict)
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
             id=args.wandb_id,
             project=args.wandb_project,
+            config=config.to_dict(),
             **{"entity": args.wandb_entity}
         )
         loggers.append(wdb_logger)
 
+    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
     if(args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
+            cluster_environment=cluster_environment,
         )
-        if(args.wandb):
+        if(args.wandb and is_rank_zero):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False,
+                               cluster_environment=cluster_environment)
     else:
         strategy = None
 
-    if(args.wandb):
+    if(args.wandb and is_rank_zero):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
+    trainer = pl.Trainer(
+        num_nodes=args.num_nodes,
+        devices=args.gpus,
+        precision=args.precision,
+        max_epochs=args.max_epochs,
         default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
         logger=loggers,
+        profiler='simple',
     )
 
     if (args.resume_model_weights_only):
@@ -623,8 +647,17 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser = pl.Trainer.add_argparse_args(parser)
+    parser.add_argument("--num_nodes", type=int, default=1)
+    parser.add_argument("--gpus", type=int, default=None)
+    parser.add_argument("--max_epochs", type=int, default=None)
+    parser.add_argument("--precision", type=str, default="32")
+    parser.add_argument("--log_every_n_steps", type=int, default=50)
+    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
+    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
+    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
+    parser.add_argument("--mpi_plugin", action="store_true", default=False)
+    # parser = pl.Trainer.add_argparse_args(parser)
 
     # Disable the initial validation pass
     parser.set_defaults(
         num_sanity_val_steps=0,
...
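Taken together, the training-script changes follow the standard Lightning 1.x → 2.x migration: plugins become strategies, `seed_everything` moves to the `pytorch_lightning` namespace, `*_epoch_end` hooks become `on_*_epoch_end`, `save_hyperparameters` must actually be called, and `Trainer.from_argparse_args` gives way to explicit argparse flags forwarded to a plain `Trainer(...)`. A condensed sketch of those renames, assuming pytorch_lightning>=2.0 (`TinyModule` is illustrative, not from the repo):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy  # was plugins.training_type.DDPPlugin


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Must be *called*: a bare `self.save_hyperparameters` is a no-op
        # attribute access, which is exactly the bug the diff fixes.
        self.save_hyperparameters()

    def on_validation_epoch_end(self):
        # Replaces the removed validation_epoch_end(self, outputs) hook.
        pass


# Trainer.from_argparse_args is gone in 2.x; flags are parsed manually and
# forwarded explicitly, as the new argparse block above does.
trainer = pl.Trainer(
    devices=1,
    strategy=DDPStrategy(find_unused_parameters=False),
)
```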