Unverified commit e938c184, authored by Jennifer Wei, committed by GitHub

Merge pull request #534 from aqlaboratory/pl_upgrades

Update openfold to use pytorch 2 and other updated dependencies
Parents: 815a042c, c587b06e
-FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
+FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
 # metainformation
-LABEL org.opencontainers.image.version = "1.0.0"
+LABEL org.opencontainers.image.version = "2.0.0"
-LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
+LABEL org.opencontainers.image.authors = "OpenFold Team"
 LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
 LABEL org.opencontainers.image.licenses = "Apache License 2.0"
-LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
+LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04"
+RUN apt-get update && apt-get install -y wget
 RUN apt-key del 7fa2af80
-RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
-RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
-RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
+RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
+RUN dpkg -i cuda-keyring_1.0-1_all.deb
+RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git
 RUN wget -P /tmp \
     "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
     && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
......
@@ -62,7 +62,7 @@ python3 run_pretrained_openfold.py \
     $TEMPLATE_MMCIF_DIR \
     --output_dir $OUTPUT_DIR \
     --config_preset model_1_ptm \
-    --uniref90_database_path $BASE_DATA_DIR/uniref90 \
+    --uniref90_database_path $BASE_DATA_DIR/uniref90/uniref90.fasta \
     --mgnify_database_path $BASE_DATA_DIR/mgnify/mgy_clusters_2018_12.fa \
     --pdb70_database_path $BASE_DATA_DIR/pdb70 \
     --uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
......
@@ -4,7 +4,7 @@ In this guide, we will OpenFold and its dependencies.
 **Pre-requisites**
-This package is currently supported for CUDA 11 and Pytorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml). To install OpenFold for CUDA 12, please refer to the [Environment specific modifications](#Environment-specific-modifications) section.
+This package is currently supported for CUDA 12 and Pytorch 2. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml).
 At this time, only Linux systems are supported.
@@ -53,12 +53,6 @@ Certain tests perform equivalence comparisons with the AlphaFold implementation.
 ## Environment specific modifications
-### CUDA 12
-To use OpenFold on CUDA 12 environment rather than a CUDA 11 environment.
-In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the command in step 1 with `git clone -b pl_upgrades https://github.com/aqlaboratory/openfold.git`
-and follow the rest of the steps of [Installation Guide](#Installation)
 ### MPI
 To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`.
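For a quick sanity check that MPI is wired up, a minimal script along these lines can help (the file name and `mpirun` invocation are illustrative, not part of this PR):

```python
# check_mpi.py -- print one line per rank; run with e.g. `mpirun -n 2 python check_mpi.py`
from mpi4py import MPI

comm = MPI.COMM_WORLD
print(f"rank {comm.Get_rank()} of {comm.Get_size()}")
```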
@@ -71,4 +65,4 @@ If you don't have access to `aws` on your system, you can use a different downlo
 ### Docker setup
-A [`Dockerfile`] is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
+A [`Dockerfile`](https://github.com/aqlaboratory/openfold/blob/main/Dockerfile) is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
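As a rough usage sketch (commands assumed here, not quoted from the linked notes): `docker build -t openfold .` from the repository root builds the image, and `docker run --gpus all -it openfold` starts a container with GPU access for running inference.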
@@ -3,36 +3,38 @@ channels:
   - conda-forge
   - bioconda
   - pytorch
+  - nvidia
 dependencies:
-  - python=3.9
-  - libgcc=7.2
+  - cuda
+  - gcc=12.4
+  - python=3.10
   - setuptools=59.5.0
   - pip
-  - openmm=7.7
+  - openmm
   - pdbfixer
   - pytorch-lightning
   - biopython
   - numpy
   - pandas
-  - PyYAML==5.4.1
+  - PyYAML
   - requests
-  - scipy==1.7
-  - tqdm==4.62.2
-  - typing-extensions==4.0
+  - scipy
+  - tqdm
+  - typing-extensions
   - wandb
   - modelcif==0.7
   - awscli
   - ml-collections
   - aria2
-  - mkl=2024.0
+  - mkl
   - git
-  - bioconda::hmmer==3.3.2
-  - bioconda::hhsuite==3.3.0
-  - bioconda::kalign2==2.04
-  - bioconda::mmseqs2
-  - pytorch::pytorch=1.12.*
+  - bioconda::hmmer
+  - bioconda::hhsuite
+  - bioconda::kalign2
+  - pytorch::pytorch=2.5
+  - pytorch::pytorch-cuda=12.4
   - pip:
-      - deepspeed==0.12.4
+      - deepspeed==0.14.5
       - dm-tree==0.1.6
       - git+https://github.com/NVIDIA/dllogger.git
-      - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
+      - flash-attn
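With the now-unpinned specs above, a typical way to materialize the environment is `mamba env create -n openfold_env -f environment.yml` followed by `conda activate openfold_env` (the environment name is illustrative); the resolved versions will drift over time, which is the trade-off of dropping the pins.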
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "view-in-github",
+    "colab_type": "text"
+   },
+   "source": [
+    "<a href=\"https://colab.research.google.com/github/aqlaboratory/OpenFold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+   ]
+  },
  {
   "cell_type": "markdown",
   "metadata": {
@@ -107,11 +117,11 @@
    "\n",
    "python_version = f\"{version_info.major}.{version_info.minor}\"\n",
    "\n",
-   "\n",
-   "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
-   "os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
+   "os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh\")\n",
+   "os.system(\"bash Miniforge3-Linux-x86_64.sh -bfp /usr/local\")\n",
+   "os.environ[\"PATH\"] = \"/usr/local/bin:\" + os.environ[\"PATH\"]\n",
    "os.system(\"mamba config --set auto_update_conda false\")\n",
-   "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python={python_version} pdbfixer biopython=1.83\")\n",
+   "os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=8.2.0 python={python_version} pdbfixer biopython=1.83\")\n",
    "os.system(\"pip install -q torch ml_collections py3Dmol modelcif\")\n",
    "\n",
    "try:\n",
@@ -127,7 +137,7 @@
    "\n",
    " %shell mkdir -p /content/openfold/openfold/resources\n",
    "\n",
-   " commit = \"3bec3e9b2d1e8bdb83887899102eff7d42dc2ba9\"\n",
+   " commit = \"1ffd197489aa5f35a5fbce1f00d7dd49bce1bd2f\"\n",
    " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
    "\n",
    " os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
@@ -893,8 +903,7 @@
  "metadata": {
   "colab": {
    "provenance": [],
-   "gpuType": "T4",
-   "toc_visible": true
+   "gpuType": "T4"
  },
  "kernelspec": {
   "display_name": "Python 3",
......
@@ -660,7 +660,7 @@ config = mlc.ConfigDict(
         },
         "relax": {
             "max_iterations": 0,  # no max
-            "tolerance": 2.39,
+            "tolerance": 10.0,
             "stiffness": 10.0,
             "max_outer_iterations": 20,
             "exclude_residues": [],
......
@@ -28,7 +28,7 @@ if ds4s_is_installed:
 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
     from flash_attn.bert_padding import unpad_input
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+    from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
 import torch
 import torch.nn as nn
@@ -808,10 +808,10 @@ def _flash_attn(q, k, v, kv_mask):
     # [B_flat, N, 2 * H * C]
     kv = kv.reshape(*kv.shape[:-3], -1)
-    kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
+    kv_unpad, _, kv_cu_seqlens, kv_max_s, _ = unpad_input(kv, kv_mask)
     kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
-    out = flash_attn_unpadded_kvpacked_func(
+    out = flash_attn_varlen_kvpacked_func(
         q,
         kv_unpad,
         q_cu_seqlens,
......
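The rename tracks FlashAttention's move from `flash_attn_unpadded_*` to `flash_attn_varlen_*` function names, and the extra `_` on the `unpad_input` call absorbs the fifth return value added in newer releases. A minimal compatibility shim, assuming only the name changed between the two package versions (a sketch, not part of this PR):

```python
# Prefer the newer varlen name; fall back to the pre-rename unpadded name.
try:
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_kvpacked_func as kvpacked_attn,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func as kvpacked_attn,
    )
```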
@@ -34,6 +34,7 @@ from openmm import app as openmm_app
 from openmm.app.internal.pdbstructure import PdbStructure
 ENERGY = unit.kilocalories_per_mole
+FORCE = unit.kilojoules_per_mole / unit.nanometer
 LENGTH = unit.angstroms
@@ -439,7 +440,7 @@ def _run_one_iteration(
     exclude_residues = exclude_residues or []
     # Assign physical dimensions.
-    tolerance = tolerance * ENERGY
+    tolerance = tolerance * FORCE
     stiffness = stiffness * ENERGY / (LENGTH ** 2)
     start = time.perf_counter()
......
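The new `FORCE` constant matches recent OpenMM releases, where the minimizer's convergence tolerance is interpreted as a force (kJ/mol/nm) rather than an energy; this is also why the relax config above moves `tolerance` from 2.39 to 10.0. A small sketch of the unit arithmetic, assuming OpenMM is installed (the values are copied from the diff):

```python
from openmm import unit

ENERGY = unit.kilocalories_per_mole
FORCE = unit.kilojoules_per_mole / unit.nanometer
LENGTH = unit.angstroms

tolerance = 10.0 * FORCE                    # force-based convergence criterion
stiffness = 10.0 * ENERGY / (LENGTH ** 2)   # harmonic restraint spring constant
```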
@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):
 def _superimpose_single(reference, coords):
     reference_np = reference.detach().to(torch.float).cpu().numpy()
     coords_np = coords.detach().to(torch.float).cpu().numpy()
     superimposed, rmsd = _superimpose_np(reference_np, coords_np)
     return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
 def superimpose(reference, coords, mask):
......
@@ -14,7 +14,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
 python setup.py install
 echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
-git clone https://github.com/NVIDIA/cutlass --depth 1
+git clone https://github.com/NVIDIA/cutlass --branch v3.6.0 --depth 1
 conda env config vars set CUTLASS_PATH=$PWD/cutlass
 # This setting is used to fix a worker assignment issue during data loading
......
@@ -29,7 +29,7 @@ version_dependent_macros = [
 ]
 extra_cuda_flags = [
-    '-std=c++14',
+    '-std=c++17',
     '-maxrregcount=50',
     '-U__CUDA_NO_HALF_OPERATORS__',
     '-U__CUDA_NO_HALF_CONVERSIONS__',
@@ -52,9 +52,9 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_major, bare_metal_minor
 compute_capabilities = set([
-    (3, 7),  # K80, e.g.
     (5, 2),  # Titan X
     (6, 1),  # GeForce 1000-series
+    (9, 0),  # Hopper
 ])
 compute_capabilities.add((7, 0))
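For context, compute-capability pairs like these are conventionally expanded into `nvcc` `-gencode` flags during the build; a sketch of the usual pattern (variable names are illustrative, not necessarily what setup.py does with them):

```python
compute_capabilities = {(5, 2), (6, 1), (7, 0), (9, 0)}

cc_flags = []
for major, minor in sorted(compute_capabilities):
    # e.g. -gencode arch=compute_90,code=sm_90 for Hopper
    cc_flags.extend(
        ["-gencode", f"arch=compute_{major}{minor},code=sm_{major}{minor}"]
    )
```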
@@ -113,7 +113,7 @@ else:
 setup(
     name='openfold',
-    version='2.0.0',
+    version='2.2.0',
     description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
     author='OpenFold Team',
     author_email='jennifer.wei@omsf.io',
@@ -130,7 +130,7 @@ setup(
     classifiers=[
         'License :: OSI Approved :: Apache Software License',
         'Operating System :: POSIX :: Linux',
-        'Programming Language :: Python :: 3.9,'
+        'Programming Language :: Python :: 3.10,'
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
 )
@@ -306,7 +306,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
         batch["residx_atom37_to_atom14"] = batch[
             "residx_atom37_to_atom14"
         ].long()
-        # print(batch["target_feat"].shape)
         batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
         batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
         batch.update(
@@ -316,8 +315,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
         # Move the recycling dimension to the end
         move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
         batch = tensor_tree_map(move_dim, batch)
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+        # Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
+        # https://github.com/aqlaboratory/openfold/issues/532
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
             model = compare_utils.get_global_pretrained_openfold()
             model.globals.use_deepspeed_evo_attention = False
             out_repro = model(batch)
......
@@ -202,4 +202,4 @@ class TestModel(unittest.TestCase):
         out_repro = out_repro["sm"]["positions"][-1]
         out_repro = out_repro.squeeze(0)
-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3)
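The assertion now bounds the mean rather than the maximum absolute difference, so a single outlier coordinate no longer fails the test. A plausible sketch of the helper (hypothetical; the real implementation lives in the test suite's `compare_utils`):

```python
import torch

def assert_mean_abs_diff_small(actual, expected, eps):
    # Fail only when the average deviation exceeds eps.
    mean_abs_diff = torch.mean(torch.abs(actual - expected)).item()
    assert mean_abs_diff < eps, f"mean abs diff {mean_abs_diff:.2e} >= {eps:.2e}"
```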
@@ -21,7 +21,6 @@ from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataM
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse_utils import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
@@ -55,7 +54,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
         self.cached_weights = None
         self.last_lr_step = -1
         self.save_hyperparameters()
@@ -73,7 +72,7 @@ class OpenFoldWrapper(pl.LightningModule):
             on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
         )
-        if(train):
+        if (train):
             self.log(
                 f"{phase}/{loss_name}_epoch",
                 indiv_loss,
@@ -82,12 +81,12 @@ class OpenFoldWrapper(pl.LightningModule):
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
                 batch,
                 outputs,
                 superimposition_metrics=(not train)
             )
-        for k,v in other_metrics.items():
+        for k, v in other_metrics.items():
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
@@ -96,7 +95,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
     def training_step(self, batch, batch_idx):
-        if(self.ema.device != batch["aatype"].device):
+        if (self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
         ground_truth = batch.pop('gt_features', None)
@@ -127,12 +126,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
-        if(self.cached_weights is None):
+        if (self.cached_weights is None):
             # model.state_dict() contains references to model weights rather
             # than copies. Therefore, we need to clone them before calling
             # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
-            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            def clone_param(t): return t.detach().clone()
+            self.cached_weights = tensor_tree_map(
+                clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])
         ground_truth = batch.pop('gt_features', None)
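This hunk is the standard EMA evaluation swap: snapshot the live weights once at the start of validation, run validation with the EMA weights, and restore the snapshot when validation ends (the restore appears in the next hunk). A condensed sketch, with function names invented for illustration:

```python
def begin_validation(module):
    # state_dict() returns references to the live tensors, so clone first.
    if module.cached_weights is None:
        module.cached_weights = {
            k: v.detach().clone() for k, v in module.model.state_dict().items()
        }
        module.model.load_state_dict(module.ema.state_dict()["params"])

def end_validation(module):
    module.model.load_state_dict(module.cached_weights)
    module.cached_weights = None
```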
@@ -160,17 +160,17 @@ class OpenFoldWrapper(pl.LightningModule):
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
     def _compute_validation_metrics(self,
                                     batch,
                                     outputs,
                                     superimposition_metrics=False
                                     ):
         metrics = {}
         gt_coords = batch["all_atom_positions"]
         pred_coords = outputs["final_atom_positions"]
         all_atom_mask = batch["all_atom_mask"]
         # This is super janky for superimposition. Fix later
         gt_coords_masked = gt_coords * all_atom_mask[..., None]
         pred_coords_masked = pred_coords * all_atom_mask[..., None]
@@ -178,7 +178,7 @@ class OpenFoldWrapper(pl.LightningModule):
         gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
         pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
         all_atom_mask_ca = all_atom_mask[..., ca_pos]
         lddt_ca_score = lddt_ca(
             pred_coords,
             gt_coords,
@@ -186,18 +186,18 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
         metrics["lddt_ca"] = lddt_ca_score
         drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
             mask=all_atom_mask_ca,  # still required here to compute n
         )
         metrics["drmsd_ca"] = drmsd_ca_score
-        if(superimposition_metrics):
+        if (superimposition_metrics):
             superimposed_pred, alignment_rmsd = superimpose(
                 gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
@@ -211,7 +211,7 @@ class OpenFoldWrapper(pl.LightningModule):
             metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score
         return metrics
     def configure_optimizers(self,
@@ -220,8 +220,8 @@ class OpenFoldWrapper(pl.LightningModule):
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
@@ -246,8 +246,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_load_checkpoint(self, checkpoint):
         ema = checkpoint["ema"]
-        if(not self.model.template_config.enabled):
-            ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k}
+        if (not self.model.template_config.enabled):
+            ema["params"] = {k: v for k,
+                             v in ema["params"].items() if not "template" in k}
         self.ema.load_state_dict(ema)
     def on_save_checkpoint(self, checkpoint):
@@ -258,13 +259,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def load_from_jax(self, jax_path):
         model_basename = os.path.splitext(
             os.path.basename(
                 os.path.normpath(jax_path)
             )
         )[0]
         model_version = "_".join(model_basename.split("_")[1:])
         import_jax_weights_(
             self.model, jax_path, version=model_version
         )
 def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
@@ -331,30 +332,31 @@ def main(args):
     if args.resume_from_jax_params:
         model_module.load_from_jax(args.resume_from_jax_params)
-        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
+        logging.info(
+            f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
     # TorchScript components of the model
-    if(args.script_modules):
+    if (args.script_modules):
         script_preset_(model_module)
     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     else:
         data_module = OpenFoldDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     data_module.prepare_data()
     data_module.setup()
     callbacks = []
-    if(args.checkpoint_every_epoch):
+    if (args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
@@ -362,7 +364,7 @@ def main(args):
         )
         callbacks.append(mc)
-    if(args.early_stopping):
+    if (args.early_stopping):
         es = EarlyStoppingVerbose(
             monitor="val/lddt_ca",
             min_delta=args.min_delta,
@@ -374,7 +376,7 @@ def main(args):
         )
         callbacks.append(es)
-    if(args.log_performance):
+    if (args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
             log_file=os.path.join(args.output_dir, "performance_log.json"),
@@ -382,7 +384,7 @@ def main(args):
         )
         callbacks.append(perf)
-    if(args.log_lr):
+    if (args.log_lr):
         lr_monitor = LearningRateMonitor(logging_interval="step")
         callbacks.append(lr_monitor)
@@ -448,7 +450,7 @@ def main(args):
     ckpt_path = args.resume_from_ckpt
     trainer.fit(
         model_module,
         datamodule=data_module,
         ckpt_path=ckpt_path,
     )
@@ -680,22 +682,22 @@ if __name__ == "__main__":
     trainer_group.add_argument(
         "--reload_dataloaders_every_n_epochs", type=int, default=1,
     )
-    trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
+    trainer_group.add_argument(
+        "--accumulate_grad_batches", type=int, default=1,
         help="Accumulate gradients over k batches before next optimizer step.")
     args = parser.parse_args()
-    if(args.seed is None and
+    if (args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
-    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
+    if (str(args.precision) == "16" and args.deepspeed_config_path is not None):
         raise ValueError("DeepSpeed and FP16 training are not compatible")
-    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
-        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
+    if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
+        raise ValueError(
+            "Choose between loading pretrained Jax-weights and a checkpoint-path")
     main(args)