Commit ce000c60 authored by Jennifer Wei

changes required for pytorch2

parent fdd4e1d8
-name: openfold-venv
+name: pytorch1-plupgrade
 channels:
   - conda-forge
   - bioconda
@@ -11,7 +11,7 @@ dependencies:
   - openmm=7.7
   - pdbfixer
   - cudatoolkit==11.3.*
-  - pytorch-lightning==1.5.10
+  - pytorch-lightning==2.0.9
   - biopython==1.79
   - numpy==1.21
   - pandas==2.0
@@ -19,7 +19,7 @@ dependencies:
   - requests
   - scipy==1.7
   - tqdm==4.62.2
-  - typing-extensions==3.10
+  - typing-extensions==4.0
   - wandb==0.12.21
   - modelcif==0.7
   - awscli
@@ -31,6 +31,7 @@ dependencies:
   - bioconda::kalign2==2.04
   - pytorch::pytorch=1.12.*
   - pip:
+      - mpi4py==3.1.5
       - deepspeed==0.12.4
       - dm-tree==0.1.6
       - git+https://github.com/NVIDIA/dllogger.git
......
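A quick post-install sanity check for the rebuilt environment (a minimal sketch, not part of the commit; assumes the env above was created and activated):

```python
import pytorch_lightning as pl
import torch

# The file pins pytorch-lightning==2.0.9 while keeping pytorch=1.12.*,
# so a correctly rebuilt env reports a 2.0.x Lightning on a 1.12 torch.
print("lightning:", pl.__version__, "torch:", torch.__version__)
assert pl.__version__.startswith("2.0")
```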
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             with open(distillation_alignment_index_path, "r") as fp:
                 self.distillation_alignment_index = json.load(fp)

-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
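Note: Lightning calls the DataModule hook as `setup(stage=...)` (with "fit", "validate", "test", or "predict"), so a zero-argument override raises a TypeError under 2.x. A minimal sketch of the expected signature (toy class, not OpenFold code):

```python
import pytorch_lightning as pl

class ToyDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # called as setup(stage="fit") during trainer.fit(), etc.
        print(f"building datasets for stage={stage}")
```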
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )

-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()
@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        return []

     def predict_dataloader(self):
         return self._gen_dataloader("predict")
@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path

-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
......
@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
     features["num_alignments"] = np.array(
         [num_alignments] * num_res, dtype=np.int32
     )
-    features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
+    features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
     return features
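Note: `dtype=object` and `dtype=np.object_` build the same object-dtype array; moving to the plain builtin just avoids NumPy's deprecated alias spellings. A standalone check:

```python
import numpy as np

species_ids = [b"9606", b"10090"]  # example byte-string IDs, not real data
a = np.array(species_ids, dtype=object)
b = np.array(species_ids, dtype=np.object_)
assert a.dtype == b.dtype == np.dtype("O")
```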
@@ -590,7 +590,7 @@ def convert_monomer_features(
 ) -> FeatureDict:
     """Reshapes and modifies monomer features for multimer models."""
     converted = {}
-    converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
+    converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
     unnecessary_leading_dim_feats = {
         'sequence', 'domain_name', 'num_alignments', 'seq_length'
     }
@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
         )
         mmcif_feats["release_date"] = np.array(
-            [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
+            [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
         )
         mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
 def _superimpose_single(reference, coords):
-    reference_np = reference.detach().cpu().numpy()
-    coords_np = coords.detach().cpu().numpy()
+    reference_np = reference.detach().to(torch.float).cpu().numpy()
+    coords_np = coords.detach().to(torch.float).cpu().numpy()
     superimposed, rmsd = _superimpose_np(reference_np, coords_np)
     return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
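Note: NumPy has no bfloat16 dtype, so a tensor coming out of bf16 mixed-precision training cannot go straight through `.numpy()`; the added `.to(torch.float)` upcast is what makes the conversion legal. A standalone repro sketch:

```python
import torch

x = torch.randn(4, 3).to(torch.bfloat16)
try:
    x.detach().cpu().numpy()  # fails: NumPy cannot represent bfloat16
except TypeError as e:
    print(e)
x_np = x.detach().to(torch.float).cpu().numpy()  # upcast first: works
```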
......
@@ -8,8 +8,11 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
+from pytorch_lightning.plugins.environments import MPIEnvironment
+from pytorch_lightning import seed_everything
 import torch
+import wandb
 from openfold.config import model_config
 from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
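Note: Lightning 2.x removed the `plugins.training_type` module; distributed backends now live under `pytorch_lightning.strategies`. A rough mapping for the classes touched here (a sketch, not an exhaustive list):

```python
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

# 1.5: pytorch_lightning.plugins.training_type.DDPPlugin / DeepSpeedPlugin
# 2.x: pytorch_lightning.strategies.DDPStrategy / DeepSpeedStrategy
ddp = DDPStrategy(find_unused_parameters=False)  # same kwargs carry over
```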
@@ -24,7 +27,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
-from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.utils.validation_metrics import (
@@ -59,7 +61,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.cached_weights = None
         self.last_lr_step = -1

-        self.save_hyperparameters
+        self.save_hyperparameters()

     def forward(self, batch):
         return self.model(batch)
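Note: the old line referenced the bound method without calling it, which is a silent no-op; with the parentheses Lightning actually records the `__init__` arguments under `self.hparams` (and into checkpoints). Sketch:

```python
import pytorch_lightning as pl

class ToyModule(pl.LightningModule):
    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()  # stores learning_rate in self.hparams

m = ToyModule()
print(m.hparams.learning_rate)  # 0.001
```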
@@ -70,14 +72,15 @@ class OpenFoldWrapper(pl.LightningModule):
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
-                on_step=train, on_epoch=(not train), logger=True,
+                prog_bar=(loss_name == 'loss'),
+                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
             )

             if(train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,
-                    on_step=False, on_epoch=True, logger=True,
+                    on_step=False, on_epoch=True, logger=True, sync_dist=False,
                 )

         with torch.no_grad():
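Note: `sync_dist` controls whether `self.log` all-reduces the value across ranks before recording it; pinning it to False logs each rank's local value and skips a per-step collective. Toy sketch of the call shape:

```python
import torch
import pytorch_lightning as pl

class ToyModule(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = torch.zeros((), requires_grad=True)
        # sync_dist=False: rank-local value, no cross-rank reduction
        self.log("train/loss", loss, on_step=True, prog_bar=True, sync_dist=False)
        return loss
```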
@@ -91,7 +94,8 @@ class OpenFoldWrapper(pl.LightningModule):
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
-                on_step=False, on_epoch=True, logger=True
+                prog_bar=(k == 'loss'),
+                on_step=False, on_epoch=True, logger=True, sync_dist=False,
             )

     def training_step(self, batch, batch_idx):
@@ -154,7 +158,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self._log(loss_breakdown, batch, outputs, train=False)

-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
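Note: Lightning 2.0 removed `validation_epoch_end(outputs)`; the replacement hook `on_validation_epoch_end` receives no outputs, so anything needed at epoch end must be accumulated manually during `validation_step`. A minimal migration sketch:

```python
import pytorch_lightning as pl

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.val_step_outputs = []  # manual accumulator, new in 2.x style

    def validation_step(self, batch, batch_idx):
        self.val_step_outputs.append({"dummy_metric": 0.0})

    def on_validation_epoch_end(self):
        # aggregate self.val_step_outputs here, then reset for next epoch
        self.val_step_outputs.clear()
```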
@@ -377,40 +381,59 @@ def main(args):
         callbacks.append(lr_monitor)

     loggers = []
+    is_rank_zero = int(os.environ.get("PMI_RANK", "0")) == 0
     if(args.wandb):
+        if args.mpi_plugin and is_rank_zero:
+            wandb_init_dict = dict(
+                name=args.experiment_name,
+                project=args.wandb_project,
+                id=args.wandb_id,
+                dir=args.output_dir,
+                resume="allow",
+                anonymous=None,
+                entity=args.wandb_entity
+            )
+            wandb.run = wandb.init(**wandb_init_dict)
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
             id=args.wandb_id,
             project=args.wandb_project,
+            config=config.to_dict(),
             **{"entity": args.wandb_entity}
         )
         loggers.append(wdb_logger)

+    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
     if(args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
+            cluster_environment=cluster_environment,
         )
-        if(args.wandb):
+        if(args.wandb and is_rank_zero):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False,
+                               cluster_environment=cluster_environment)
     else:
         strategy = None

-    if(args.wandb):
+    if(args.wandb and is_rank_zero):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")

-    trainer = pl.Trainer.from_argparse_args(
-        args,
+    trainer = pl.Trainer(
+        num_nodes=args.num_nodes,
+        devices=args.gpus,
+        precision=args.precision,
+        max_epochs=args.max_epochs,
         default_root_dir=args.output_dir,
         strategy=strategy,
         callbacks=callbacks,
         logger=loggers,
+        profiler='simple',
     )

     if(args.resume_model_weights_only):
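Note: `MPIEnvironment` (the reason `mpi4py` was added to the environment above) lets the strategy read rank and world size from the MPI launcher, and `PMI_RANK` is only set by MPI process managers, hence the guarded default for non-MPI runs. Standalone sketch of the guard:

```python
import os
from pytorch_lightning.plugins.environments import MPIEnvironment

mpi_plugin = False  # stand-in for args.mpi_plugin
is_rank_zero = int(os.environ.get("PMI_RANK", "0")) == 0
cluster_environment = MPIEnvironment() if mpi_plugin else None
print(is_rank_zero, cluster_environment)
```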
@@ -621,7 +644,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser = pl.Trainer.add_argparse_args(parser)
+    parser.add_argument("--num_nodes", type=int, default=1)
+    parser.add_argument("--gpus", type=int, default=None)
+    parser.add_argument("--max_epochs", type=int, default=None)
+    parser.add_argument("--precision", type=str, default="32")
+    parser.add_argument("--log_every_n_steps", type=int, default=50)
+    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
+    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
+    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
+    parser.add_argument("--mpi_plugin", action="store_true", default=False)
+    # parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
     parser.set_defaults(
......
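Note: `pl.Trainer.add_argparse_args` and `from_argparse_args` were removed in Lightning 2.0, which is why the script now declares just the Trainer flags it uses and forwards them explicitly. The shape of that pattern (trimmed sketch):

```python
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--max_epochs", type=int, default=None)
parser.add_argument("--precision", type=str, default="32")
args = parser.parse_args([])  # empty argv: take the defaults

trainer = pl.Trainer(max_epochs=args.max_epochs, precision=args.precision)
```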