Commit 721e43be authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add GPU relaxation

parent fe98aa32
FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04 FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04
# I'm not sure why i needed both opencl and cuda here, but the relax phase of the script needed opencl RUN apt-get update && apt-get install -y wget cuda-minimal-build-11-0 git
RUN apt-get update && apt-get install -y wget cuda-minimal-build-11-0 nvidia-opencl-dev git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
&& bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
......
...@@ -477,7 +477,9 @@ ...@@ -477,7 +477,9 @@
" tolerance=2.39,\n", " tolerance=2.39,\n",
" stiffness=10.0,\n", " stiffness=10.0,\n",
" exclude_residues=[],\n", " exclude_residues=[],\n",
" max_outer_iterations=20)\n", " max_outer_iterations=20,\n"
" use_gpu=True,\n"
" )\n",
" # Find the best model according to the mean pLDDT.\n", " # Find the best model according to the mean pLDDT.\n",
" best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n", " best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n",
" relaxed_pdb, _, _ = amber_relaxer.process(\n", " relaxed_pdb, _, _ = amber_relaxer.process(\n",
......
from . import model from . import model
from . import utils from . import utils
from . import np from . import np
from . import resources
__all__ = ["model", "utils", "np", "data"] __all__ = ["model", "utils", "np", "data", "resources"]
...@@ -127,10 +127,10 @@ class MSAAttention(nn.Module): ...@@ -127,10 +127,10 @@ class MSAAttention(nn.Module):
): ):
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm_z(z) z = self.layer_norm_z(z)
# [*, N_res, N_res, no_heads] # [*, N_res, N_res, no_heads]
z = self.linear_z(z) z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res] # [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4) z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
......
...@@ -83,6 +83,7 @@ def _openmm_minimize( ...@@ -83,6 +83,7 @@ def _openmm_minimize(
stiffness: unit.Unit, stiffness: unit.Unit,
restraint_set: str, restraint_set: str,
exclude_residues: Sequence[int], exclude_residues: Sequence[int],
use_gpu: bool,
): ):
"""Minimize energy via openmm.""" """Minimize energy via openmm."""
...@@ -96,7 +97,7 @@ def _openmm_minimize( ...@@ -96,7 +97,7 @@ def _openmm_minimize(
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues) _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)
integrator = openmm.LangevinIntegrator(0, 0.01, 0.0) integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
platform = openmm.Platform.getPlatformByName("CPU") platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
simulation = openmm_app.Simulation( simulation = openmm_app.Simulation(
pdb.topology, system, integrator, platform pdb.topology, system, integrator, platform
) )
...@@ -411,6 +412,7 @@ def _run_one_iteration( ...@@ -411,6 +412,7 @@ def _run_one_iteration(
restraint_set: str, restraint_set: str,
max_attempts: int, max_attempts: int,
exclude_residues: Optional[Collection[int]] = None, exclude_residues: Optional[Collection[int]] = None,
use_gpu: bool,
): ):
"""Runs the minimization pipeline. """Runs the minimization pipeline.
...@@ -425,7 +427,7 @@ def _run_one_iteration( ...@@ -425,7 +427,7 @@ def _run_one_iteration(
max_attempts: The maximum number of minimization attempts. max_attempts: The maximum number of minimization attempts.
exclude_residues: An optional list of zero-indexed residues to exclude from exclude_residues: An optional list of zero-indexed residues to exclude from
restraints. restraints.
use_gpu: Whether to run relaxation on GPU
Returns: Returns:
A `dict` of minimization info. A `dict` of minimization info.
""" """
...@@ -451,9 +453,11 @@ def _run_one_iteration( ...@@ -451,9 +453,11 @@ def _run_one_iteration(
stiffness=stiffness, stiffness=stiffness,
restraint_set=restraint_set, restraint_set=restraint_set,
exclude_residues=exclude_residues, exclude_residues=exclude_residues,
use_gpu=use_gpu,
) )
minimized = True minimized = True
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
print(e)
logging.info(e) logging.info(e)
if not minimized: if not minimized:
raise ValueError(f"Minimization failed after {max_attempts} attempts.") raise ValueError(f"Minimization failed after {max_attempts} attempts.")
...@@ -465,6 +469,7 @@ def _run_one_iteration( ...@@ -465,6 +469,7 @@ def _run_one_iteration(
def run_pipeline( def run_pipeline(
prot: protein.Protein, prot: protein.Protein,
stiffness: float, stiffness: float,
use_gpu: bool,
max_outer_iterations: int = 1, max_outer_iterations: int = 1,
place_hydrogens_every_iteration: bool = True, place_hydrogens_every_iteration: bool = True,
max_iterations: int = 0, max_iterations: int = 0,
...@@ -483,6 +488,7 @@ def run_pipeline( ...@@ -483,6 +488,7 @@ def run_pipeline(
Args: Args:
prot: A protein to be relaxed. prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness. stiffness: kcal/mol A**2, the restraint stiffness.
use_gpu: Whether to run on GPU
max_outer_iterations: The maximum number of iterative minimization. max_outer_iterations: The maximum number of iterative minimization.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization. prior to every minimization.
...@@ -519,6 +525,7 @@ def run_pipeline( ...@@ -519,6 +525,7 @@ def run_pipeline(
stiffness=stiffness, stiffness=stiffness,
restraint_set=restraint_set, restraint_set=restraint_set,
max_attempts=max_attempts, max_attempts=max_attempts,
use_gpu=use_gpu,
) )
prot = protein.from_pdb_string(ret["min_pdb"]) prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration: if place_hydrogens_every_iteration:
......
...@@ -22,7 +22,6 @@ import numpy as np ...@@ -22,7 +22,6 @@ import numpy as np
class AmberRelaxation(object): class AmberRelaxation(object):
"""Amber relaxation.""" """Amber relaxation."""
def __init__( def __init__(
self, self,
*, *,
...@@ -30,7 +29,8 @@ class AmberRelaxation(object): ...@@ -30,7 +29,8 @@ class AmberRelaxation(object):
tolerance: float, tolerance: float,
stiffness: float, stiffness: float,
exclude_residues: Sequence[int], exclude_residues: Sequence[int],
max_outer_iterations: int max_outer_iterations: int,
use_gpu: bool,
): ):
"""Initialize Amber Relaxer. """Initialize Amber Relaxer.
...@@ -46,6 +46,7 @@ class AmberRelaxation(object): ...@@ -46,6 +46,7 @@ class AmberRelaxation(object):
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations. slowdown. In the worst case we do 20 outer iterations.
use_gpu: Whether to run on GPU
""" """
self._max_iterations = max_iterations self._max_iterations = max_iterations
...@@ -53,6 +54,7 @@ class AmberRelaxation(object): ...@@ -53,6 +54,7 @@ class AmberRelaxation(object):
self._stiffness = stiffness self._stiffness = stiffness
self._exclude_residues = exclude_residues self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations self._max_outer_iterations = max_outer_iterations
self._use_gpu = use_gpu
def process( def process(
self, *, prot: protein.Protein self, *, prot: protein.Protein
...@@ -65,6 +67,7 @@ class AmberRelaxation(object): ...@@ -65,6 +67,7 @@ class AmberRelaxation(object):
stiffness=self._stiffness, stiffness=self._stiffness,
exclude_residues=self._exclude_residues, exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations, max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
) )
min_pos = out["pos"] min_pos = out["pos"]
start_pos = out["posinit"] start_pos = out["posinit"]
......
...@@ -19,9 +19,6 @@ import logging ...@@ -19,9 +19,6 @@ import logging
import numpy as np import numpy as np
import os import os
# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import pickle import pickle
import random import random
import sys import sys
...@@ -152,19 +149,32 @@ def main(args): ...@@ -152,19 +149,32 @@ def main(args):
result=out, result=out,
b_factors=plddt_b_factors b_factors=plddt_b_factors
) )
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
**config.relax use_gpu=(args.model_device != "cpu"),
**config.relax,
) )
# Relax the prediction. # Relax the prediction.
t = time.perf_counter() t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if("cuda" in args.model_device):
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
logging.info(f"Relaxation time: {time.perf_counter() - t}") logging.info(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB. # Save the relaxed PDB.
relaxed_output_path = os.path.join( relaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}.pdb' args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
) )
with open(relaxed_output_path, 'w') as f: with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str) f.write(relaxed_pdb_str)
...@@ -175,7 +185,9 @@ if __name__ == "__main__": ...@@ -175,7 +185,9 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"fasta_path", type=str, "fasta_path", type=str,
) )
add_data_args(parser) parser.add_argument(
"template_mmcif_dir", type=str,
)
parser.add_argument( parser.add_argument(
"--use_precomputed_alignments", type=str, default=None, "--use_precomputed_alignments", type=str, default=None,
help="""Path to alignment directory. If provided, alignment computation help="""Path to alignment directory. If provided, alignment computation
...@@ -184,7 +196,6 @@ if __name__ == "__main__": ...@@ -184,7 +196,6 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--output_dir", type=str, default=os.getcwd(), "--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""", help="""Name of the directory in which to output the prediction""",
required=True
) )
parser.add_argument( parser.add_argument(
"--model_device", type=str, default="cpu", "--model_device", type=str, default="cpu",
...@@ -213,7 +224,7 @@ if __name__ == "__main__": ...@@ -213,7 +224,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
'--data_random_seed', type=str, default=None '--data_random_seed', type=str, default=None
) )
add_data_args(parser)
args = parser.parse_args() args = parser.parse_args()
if(args.param_path is None): if(args.param_path is None):
......
...@@ -12,9 +12,6 @@ def add_data_args(parser: argparse.ArgumentParser): ...@@ -12,9 +12,6 @@ def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
'--pdb70_database_path', type=str, default=None, '--pdb70_database_path', type=str, default=None,
) )
parser.add_argument(
'--template_mmcif_dir', type=str, default=None,
)
parser.add_argument( parser.add_argument(
'--uniclust30_database_path', type=str, default=None, '--uniclust30_database_path', type=str, default=None,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment