Commit 57a2455e authored by Augustin Zidek

Release code for v2.1.2

PiperOrigin-RevId: 424591903
parent db039ef4
README.md
@@ -149,7 +149,7 @@ parameters are made available under the terms of the CC BY 4.0 license. Please
 see the [Disclaimer](#license-and-disclaimer) below for more detail.
 The AlphaFold parameters are available from
-https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and
+https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar, and
 are downloaded as part of the `scripts/download_all_data.sh` script. This script
 will download parameters for:
...
alphafold/data/msa_pairing.py
@@ -16,7 +16,6 @@
 import collections
 import functools
-import re
 import string
 from typing import Any, Dict, Iterable, List, Sequence
@@ -58,14 +57,6 @@ TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
 CHAIN_FEATURES = ('num_alignments', 'seq_length')

-domain_name_pattern = re.compile(
-    r'''^(?P<pdb>[a-z\d]{4})
-\{(?P<bioassembly>[\d+(\+\d+)?])\}
-(?P<chain>[a-zA-Z\d]+)
-\{(?P<transform_index>\d+)\}$
-''', re.VERBOSE)

 def create_paired_features(
     chains: Iterable[pipeline.FeatureDict],
     prokaryotic: bool,
@@ -618,6 +609,7 @@ def deduplicate_unpaired_sequences(
   msa_features = MSA_FEATURES
   for chain in np_chains:
+    # Convert the msa_all_seq numpy array to a tuple for hashing.
     sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
     keep_rows = []
     # Go through unpaired MSA seqs and remove any rows that correspond to the
@@ -627,12 +619,6 @@ def deduplicate_unpaired_sequences(
         keep_rows.append(row_num)
     for feature_name in feature_names:
       if feature_name in msa_features:
-        if keep_rows:
-          chain[feature_name] = chain[feature_name][keep_rows]
-        else:
-          new_shape = list(chain[feature_name].shape)
-          new_shape[0] = 0
-          chain[feature_name] = np.zeros(new_shape,
-                                         dtype=chain[feature_name].dtype)
+        chain[feature_name] = chain[feature_name][keep_rows]
     chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
   return np_chains
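The deleted `else` branch special-cased an empty `keep_rows` list, but NumPy fancy indexing already returns a correctly shaped, correctly typed empty array in that case, so the single-line form is safe. A minimal standalone sketch (not AlphaFold code) illustrating why:

```python
import numpy as np

msa = np.zeros((4, 3), dtype=np.int32)  # 4 alignment rows, 3 columns

keep_rows = [0, 2]
print(msa[keep_rows].shape)  # (2, 3)

keep_rows = []  # every row filtered out
print(msa[keep_rows].shape)  # (0, 3) -- empty first axis, trailing shape kept
print(msa[keep_rows].dtype)  # int32   -- dtype preserved, no np.zeros needed
```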
alphafold/data/parsers.py
@@ -20,6 +20,9 @@ import re
 import string
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
+# Internal import (7716).

 DeletionMatrix = Sequence[Sequence[int]]
@@ -271,24 +274,27 @@ def _keep_line(line: str, seqnames: Set[str]) -> bool:
   return seqname in seqnames

-def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
-  """Truncates a stockholm file to a maximum number of sequences."""
+def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
+  """Reads + truncates a Stockholm file while preventing excessive RAM usage."""
   seqnames = set()
   filtered_lines = []
-  for line in stockholm_msa.splitlines():
-    if line.strip() and not line.startswith(('#', '//')):
-      # Ignore blank lines, markup and end symbols - remainder are alignment
-      # sequence parts.
-      seqname = line.partition(' ')[0]
-      seqnames.add(seqname)
-      if len(seqnames) >= max_sequences:
-        break
-  for line in stockholm_msa.splitlines():
-    if _keep_line(line, seqnames):
-      filtered_lines.append(line)
-  return '\n'.join(filtered_lines) + '\n'
+  with open(stockholm_msa_path) as f:
+    for line in f:
+      if line.strip() and not line.startswith(('#', '//')):
+        # Ignore blank lines, markup and end symbols - remainder are alignment
+        # sequence parts.
+        seqname = line.partition(' ')[0]
+        seqnames.add(seqname)
+        if len(seqnames) >= max_sequences:
+          break
+    f.seek(0)
+    for line in f:
+      if _keep_line(line, seqnames):
+        filtered_lines.append(line)
+  return ''.join(filtered_lines)

 def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
...
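`truncate_stockholm_msa` now takes a file path and streams the alignment in two passes: the first pass collects the names of the first `max_sequences` sequences, then the file is rewound and only lines belonging to those sequences are kept. A multi-gigabyte Stockholm hit file therefore never has to be fully materialised in memory. A hedged usage sketch (the file path and cap are illustrative):

```python
from alphafold.data import parsers

# Old API (pre-2.1.2): the caller read the whole file into memory first:
#   with open('uniref90_hits.sto') as f:
#     truncated = parsers.truncate_stockholm_msa(f.read(), max_sequences=10000)
# New API: pass the path and let the parser stream the file line by line.
truncated = parsers.truncate_stockholm_msa(
    'uniref90_hits.sto', max_sequences=10000)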
alphafold/data/pipeline.py
@@ -91,16 +91,25 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
 def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
                  msa_format: str, use_precomputed_msas: bool,
+                 max_sto_sequences: Optional[int] = None
                  ) -> Mapping[str, Any]:
   """Runs an MSA tool, checking if output already exists first."""
   if not use_precomputed_msas or not os.path.exists(msa_out_path):
-    result = msa_runner.query(input_fasta_path)[0]
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      result = msa_runner.query(input_fasta_path, max_sto_sequences)[0]  # pytype: disable=wrong-arg-count
+    else:
+      result = msa_runner.query(input_fasta_path)[0]
     with open(msa_out_path, 'w') as f:
       f.write(result[msa_format])
   else:
     logging.warning('Reading MSA from file %s', msa_out_path)
-    with open(msa_out_path, 'r') as f:
-      result = {msa_format: f.read()}
+    if msa_format == 'sto' and max_sto_sequences is not None:
+      precomputed_msa = parsers.truncate_stockholm_msa(
+          msa_out_path, max_sto_sequences)
+      result = {'sto': precomputed_msa}
+    else:
+      with open(msa_out_path, 'r') as f:
+        result = {msa_format: f.read()}
   return result
@@ -157,18 +166,23 @@ class DataPipeline:
     uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
     jackhmmer_uniref90_result = run_msa_tool(
-        self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
-        'sto', self.use_precomputed_msas)
+        msa_runner=self.jackhmmer_uniref90_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=uniref90_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.uniref_max_hits)
     mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
     jackhmmer_mgnify_result = run_msa_tool(
-        self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
-        self.use_precomputed_msas)
+        msa_runner=self.jackhmmer_mgnify_runner,
+        input_fasta_path=input_fasta_path,
+        msa_out_path=mgnify_out_path,
+        msa_format='sto',
+        use_precomputed_msas=self.use_precomputed_msas,
+        max_sto_sequences=self.mgnify_max_hits)

     msa_for_templates = jackhmmer_uniref90_result['sto']
-    msa_for_templates = parsers.truncate_stockholm_msa(
-        msa_for_templates, max_sequences=self.uniref_max_hits)
-    msa_for_templates = parsers.deduplicate_stockholm_msa(
-        msa_for_templates)
+    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
     msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
         msa_for_templates)
@@ -187,9 +201,7 @@ class DataPipeline:
       f.write(pdb_templates_result)

     uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
-    uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
     mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
-    mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)

     pdb_template_hits = self.template_searcher.get_template_hits(
         output_string=pdb_templates_result, input_sequence=input_sequence)
@@ -197,14 +209,20 @@ class DataPipeline:
     if self._use_small_bfd:
       bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
       jackhmmer_small_bfd_result = run_msa_tool(
-          self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
-          'sto', self.use_precomputed_msas)
+          msa_runner=self.jackhmmer_small_bfd_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='sto',
+          use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
       bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
       hhblits_bfd_uniclust_result = run_msa_tool(
-          self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
-          'a3m', self.use_precomputed_msas)
+          msa_runner=self.hhblits_bfd_uniclust_runner,
+          input_fasta_path=input_fasta_path,
+          msa_out_path=bfd_out_path,
+          msa_format='a3m',
+          use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

     templates_result = self.template_featurizer.get_templates(
...
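With `max_sto_sequences` plumbed through `run_msa_tool`, the hit cap is applied at the tool boundary for Stockholm outputs, both when an MSA is freshly computed and when a precomputed one is re-read; that is why the later `parsers.truncate_stockholm_msa` and `Msa.truncate` calls could be deleted. A hedged sketch of a call (paths and the 10,000-sequence cap are illustrative):

```python
from alphafold.data import parsers
from alphafold.data import pipeline
from alphafold.data.tools import jackhmmer

# Illustrative paths; adjust to your installation.
uniref90_runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',
    database_path='/data/uniref90/uniref90.fasta')

uniref90_result = pipeline.run_msa_tool(
    msa_runner=uniref90_runner,
    input_fasta_path='query.fasta',
    msa_out_path='output/uniref90_hits.sto',
    msa_format='sto',
    use_precomputed_msas=False,
    max_sto_sequences=10000)  # cap applied while querying or re-reading
uniref90_msa = parsers.parse_stockholm(uniref90_result['sto'])
```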
alphafold/data/tools/jackhmmer.py
@@ -23,6 +23,7 @@ from urllib import request
 from absl import logging
+from alphafold.data import parsers
 from alphafold.data.tools import utils
 # Internal import (7716).
@@ -86,8 +87,10 @@ class Jackhmmer:
     self.get_tblout = get_tblout
     self.streaming_callback = streaming_callback

-  def _query_chunk(self, input_fasta_path: str, database_path: str
-                   ) -> Mapping[str, Any]:
+  def _query_chunk(self,
+                   input_fasta_path: str,
+                   database_path: str,
+                   max_sequences: Optional[int] = None) -> Mapping[str, Any]:
     """Queries the database chunk using Jackhmmer."""
     with utils.tmpdir_manager() as query_tmp_dir:
       sto_path = os.path.join(query_tmp_dir, 'output.sto')
@@ -145,8 +148,11 @@ class Jackhmmer:
       with open(tblout_path) as f:
         tbl = f.read()

-      with open(sto_path) as f:
-        sto = f.read()
+      if max_sequences is None:
+        with open(sto_path) as f:
+          sto = f.read()
+      else:
+        sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

     raw_output = dict(
         sto=sto,
@@ -157,10 +163,14 @@ class Jackhmmer:
     return raw_output

-  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
+  def query(self,
+            input_fasta_path: str,
+            max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
     """Queries the database using Jackhmmer."""
     if self.num_streamed_chunks is None:
-      return [self._query_chunk(input_fasta_path, self.database_path)]
+      single_chunk_result = self._query_chunk(
+          input_fasta_path, self.database_path, max_sequences)
+      return [single_chunk_result]

     db_basename = os.path.basename(self.database_path)
     db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
@@ -187,8 +197,8 @@ class Jackhmmer:
         # Run Jackhmmer with the chunk
         future.result()
-        chunked_output.append(
-            self._query_chunk(input_fasta_path, db_local_chunk(i)))
+        chunked_output.append(self._query_chunk(
+            input_fasta_path, db_local_chunk(i), max_sequences))

         # Remove the local copy of the chunk
         os.remove(db_local_chunk(i))
...
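Because the Stockholm output of each database chunk is now truncated via the path-based `parsers.truncate_stockholm_msa` right after the search, only the first `max_sequences` hit sequences are ever held in memory. A hedged sketch of calling the runner directly (binary and database paths are illustrative):

```python
from alphafold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',               # illustrative path
    database_path='/data/uniref90/uniref90.fasta')  # illustrative path

# Omitting max_sequences preserves the old behaviour of reading the full
# Stockholm output; passing it truncates while streaming the output file.
full = runner.query('query.fasta')[0]['sto']
capped = runner.query('query.fasta', max_sequences=5000)[0]['sto']
```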
alphafold/model/folding_multimer.py
@@ -186,7 +186,7 @@ class PointProjection(hk.Module):

 class InvariantPointAttention(hk.Module):
-  """Covariant attention module.
+  """Invariant point attention module.

   The high-level idea is that this attention module works over a set of points
   and associated orientations in 3D space (e.g. protein residues).
...
alphafold/relax/amber_minimize.py
@@ -76,7 +76,8 @@ def _openmm_minimize(
     tolerance: unit.Unit,
     stiffness: unit.Unit,
     restraint_set: str,
-    exclude_residues: Sequence[int]):
+    exclude_residues: Sequence[int],
+    use_gpu: bool):
   """Minimize energy via openmm."""
   pdb_file = io.StringIO(pdb_str)
@@ -90,7 +91,7 @@ def _openmm_minimize(
     _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

   integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
-  platform = openmm.Platform.getPlatformByName("CPU")
+  platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
   simulation = openmm_app.Simulation(
       pdb.topology, system, integrator, platform)
   simulation.context.setPositions(pdb.positions)
@@ -371,6 +372,7 @@ def _run_one_iteration(
     stiffness: float,
     restraint_set: str,
     max_attempts: int,
+    use_gpu: bool,
     exclude_residues: Optional[Collection[int]] = None):
   """Runs the minimization pipeline.
@@ -383,6 +385,7 @@ def _run_one_iteration(
       potential.
     restraint_set: The set of atoms to restrain.
     max_attempts: The maximum number of minimization attempts.
+    use_gpu: Whether to run on GPU.
     exclude_residues: An optional list of zero-indexed residues to exclude from
       restraints.
@@ -407,7 +410,8 @@ def _run_one_iteration(
           pdb_string, max_iterations=max_iterations,
           tolerance=tolerance, stiffness=stiffness,
           restraint_set=restraint_set,
-          exclude_residues=exclude_residues)
+          exclude_residues=exclude_residues,
+          use_gpu=use_gpu)
       minimized = True
     except Exception as e:  # pylint: disable=broad-except
       logging.info(e)
@@ -421,6 +425,7 @@
 def run_pipeline(
     prot: protein.Protein,
     stiffness: float,
+    use_gpu: bool,
     max_outer_iterations: int = 1,
     place_hydrogens_every_iteration: bool = True,
     max_iterations: int = 0,
@@ -438,6 +443,7 @@ def run_pipeline(
   Args:
     prot: A protein to be relaxed.
     stiffness: kcal/mol A**2, the restraint stiffness.
+    use_gpu: Whether to run on GPU.
     max_outer_iterations: The maximum number of iterative minimization.
     place_hydrogens_every_iteration: Whether hydrogens are re-initialized
       prior to every minimization.
@@ -473,7 +479,8 @@ def run_pipeline(
         tolerance=tolerance,
         stiffness=stiffness,
         restraint_set=restraint_set,
-        max_attempts=max_attempts)
+        max_attempts=max_attempts,
+        use_gpu=use_gpu)
     prot = protein.from_pdb_string(ret["min_pdb"])
     if place_hydrogens_every_iteration:
      pdb_string = clean_protein(prot, checks=True)
...
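`_openmm_minimize` now picks the OpenMM platform from the `use_gpu` flag. Note that `Platform.getPlatformByName('CUDA')` raises when no CUDA device or driver is present, so `use_gpu=True` must only be passed on GPU machines. A hedged sketch of a defensive variant (the CPU fallback is our addition, not AlphaFold behaviour; the `simtk.openmm` namespace matches the OpenMM version pinned by AlphaFold v2.1.x):

```python
from simtk import openmm


def get_minimization_platform(use_gpu: bool) -> openmm.Platform:
  """Selects the OpenMM platform, falling back to CPU if CUDA is unavailable."""
  if use_gpu:
    try:
      return openmm.Platform.getPlatformByName('CUDA')
    except Exception:  # OpenMM raises if no CUDA platform/device is present.
      return openmm.Platform.getPlatformByName('CPU')
  return openmm.Platform.getPlatformByName('CPU')
```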
alphafold/relax/amber_minimize_test.py
@@ -21,6 +21,8 @@ from alphafold.relax import amber_minimize
 import numpy as np
 # Internal import (7716).

+_USE_GPU = False
+

 def _load_test_protein(data_path):
   pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
@@ -35,7 +37,7 @@ class AmberMinimizeTest(absltest.TestCase):
         'alphafold/relax/testdata/multiple_disulfides_target.pdb'
     )
     ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
-                                      stiffness=10.)
+                                      stiffness=10., use_gpu=_USE_GPU)
     self.assertIn('opt_time', ret)
     self.assertIn('min_attempts', ret)
@@ -50,7 +52,8 @@ class AmberMinimizeTest(absltest.TestCase):
         ' residues. This protein contains at least one residue with no atoms.'):
       amber_minimize.run_pipeline(prot, max_iterations=10,
                                   stiffness=1.,
-                                  max_attempts=1)
+                                  max_attempts=1,
+                                  use_gpu=_USE_GPU)

   def test_iterative_relax(self):
     prot = _load_test_protein(
@@ -59,7 +62,7 @@ class AmberMinimizeTest(absltest.TestCase):
     violations = amber_minimize.get_violation_metrics(prot)
     self.assertGreater(violations['num_residue_violations'], 0)
     out = amber_minimize.run_pipeline(
-        prot=prot, max_outer_iterations=10, stiffness=10.)
+        prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
     self.assertLess(out['efinal'], out['einit'])
     self.assertEqual(0, out['num_residue_violations'])
...
alphafold/relax/relax.py
@@ -29,7 +29,8 @@ class AmberRelaxation(object):
                tolerance: float,
                stiffness: float,
                exclude_residues: Sequence[int],
-               max_outer_iterations: int):
+               max_outer_iterations: int,
+               use_gpu: bool):
     """Initialize Amber Relaxer.

     Args:
@@ -44,6 +45,7 @@ class AmberRelaxation(object):
         CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
         as soon as there are no violations, hence in most cases this causes no
         slowdown. In the worst case we do 20 outer iterations.
+      use_gpu: Whether to run on GPU.
     """
     self._max_iterations = max_iterations
@@ -51,6 +53,7 @@ class AmberRelaxation(object):
     self._stiffness = stiffness
     self._exclude_residues = exclude_residues
     self._max_outer_iterations = max_outer_iterations
+    self._use_gpu = use_gpu

   def process(self, *,
               prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
@@ -59,7 +62,8 @@ class AmberRelaxation(object):
         prot=prot, max_iterations=self._max_iterations,
         tolerance=self._tolerance, stiffness=self._stiffness,
         exclude_residues=self._exclude_residues,
-        max_outer_iterations=self._max_outer_iterations)
+        max_outer_iterations=self._max_outer_iterations,
+        use_gpu=self._use_gpu)
     min_pos = out['pos']
     start_pos = out['posinit']
     rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
...
alphafold/relax/relax_test.py
@@ -34,7 +34,8 @@ class RunAmberRelaxTest(absltest.TestCase):
         'tolerance': 2.39,
         'stiffness': 10.0,
         'exclude_residues': [],
-        'max_outer_iterations': 1}
+        'max_outer_iterations': 1,
+        'use_gpu': False}

   def test_process(self):
     amber_relax = relax.AmberRelaxation(**self.test_config)
...
docker/Dockerfile
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-ARG CUDA=11.0
+ARG CUDA=11.1
 FROM nvidia/cuda:${CUDA}-cudnn8-runtime-ubuntu18.04
 # FROM directive resets ARGS, so we specify again (the value is retained if
 # previously set).
@@ -65,13 +65,16 @@ RUN wget -q -P /app/alphafold/alphafold/common/ \
 # Install pip packages.
 RUN pip3 install --upgrade pip \
     && pip3 install -r /app/alphafold/requirements.txt \
-    && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA/./} -f \
+    && pip3 install --upgrade jax==0.2.14 jaxlib==0.1.69+cuda${CUDA/./} -f \
       https://storage.googleapis.com/jax-releases/jax_releases.html

 # Apply OpenMM patch.
 WORKDIR /opt/conda/lib/python3.7/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch

+# Add SETUID bit to the ldconfig binary so that non-root users can run it.
+RUN chmod u+s /sbin/ldconfig.real
+
 # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
 # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for
 # details.
...
docker/run_docker.py
@@ -28,6 +28,14 @@ from docker import types

 flags.DEFINE_bool(
     'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
+flags.DEFINE_boolean(
+    'run_relax', True,
+    'Whether to run the final relaxation step on the predicted models. Turning '
+    'relax off might result in predictions with distracting stereochemical '
+    'violations but might help in case you are having issues with the '
+    'relaxation stage.')
+flags.DEFINE_bool(
+    'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
 flags.DEFINE_string(
     'gpu_devices', 'all',
     'Comma separated list of devices to pass to NVIDIA_VISIBLE_DEVICES.')
@@ -72,8 +80,17 @@ flags.DEFINE_boolean(
     'for inferencing many proteins.')
 flags.DEFINE_boolean(
     'use_precomputed_msas', False,
-    'Whether to read MSAs that have been written to disk. WARNING: This will '
-    'not check if the sequence, database or configuration have changed.')
+    'Whether to read MSAs that have been written to disk instead of running '
+    'the MSA tools. The MSA files are looked up in the output directory, so it '
+    'must stay the same between multiple runs that are to reuse the MSAs. '
+    'WARNING: This will not check if the sequence, database or configuration '
+    'have changed.')
+flags.DEFINE_string(
+    'docker_user', f'{os.geteuid()}:{os.getegid()}',
+    'UID:GID with which to run the Docker container. The output directories '
+    'will be owned by this user:group. By default, this is the current user. '
+    'Valid options are: uid or uid:gid, non-numeric values are not recognised '
+    'by Docker unless that user has been created within the container.')
 FLAGS = flags.FLAGS
@@ -84,6 +101,9 @@ def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
   path = os.path.abspath(path)
   source_path = os.path.dirname(path)
   target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
+  if not os.path.exists(source_path):
+    raise ValueError(f'Failed to find source directory "{source_path}" to '
+                     'mount in Docker container.')
   logging.info('Mounting %s -> %s', source_path, target_path)
   mount = types.Mount(target_path, source_path, type='bind', read_only=True)
   return mount, os.path.join(target_path, os.path.basename(path))
@@ -184,6 +204,8 @@ def main(argv):
   output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
   mounts.append(types.Mount(output_target_path, FLAGS.output_dir, type='bind'))

+  use_gpu_relax = FLAGS.enable_gpu_relax and FLAGS.use_gpu
+
   command_args.extend([
       f'--output_dir={output_target_path}',
       f'--max_template_date={FLAGS.max_template_date}',
@@ -191,6 +213,8 @@ def main(argv):
       f'--model_preset={FLAGS.model_preset}',
       f'--benchmark={FLAGS.benchmark}',
       f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
+      f'--run_relax={FLAGS.run_relax}',
+      f'--use_gpu_relax={use_gpu_relax}',
       '--logtostderr',
   ])
@@ -206,6 +230,7 @@ def main(argv):
       remove=True,
       detach=True,
       mounts=mounts,
+      user=FLAGS.docker_user,
       environment={
           'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices,
           # The following flags allow us to make predictions on proteins that
...
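The new `--docker_user` flag defaults to the invoking user's `uid:gid`, so output files written by the container are owned by that user rather than root (which is also why the Dockerfile now sets the SETUID bit on `ldconfig.real`). A minimal sketch of the underlying Docker SDK call (the `alphafold` image tag and command are illustrative, not this script's flag handling):

```python
import os

import docker  # Docker SDK for Python, as used by run_docker.py

client = docker.from_env()
client.containers.run(
    image='alphafold',                      # illustrative image tag
    command=['--helpshort'],
    user=f'{os.geteuid()}:{os.getegid()}',  # same default as --docker_user
    remove=True,                            # clean up the container on exit
)
```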
notebooks/AlphaFold.ipynb
@@ -14,7 +14,7 @@
         "\n",
         "In comparison to AlphaFold v2.1.0, this Colab notebook uses **no templates (homologous structures)** and a selected portion of the [BFD database](https://bfd.mmseqs.com/). We have validated these changes on several thousand recent PDB structures. While accuracy will be near-identical to the full AlphaFold system on many targets, a small fraction have a large drop in accuracy due to the smaller MSA and lack of templates. For best reliability, we recommend instead using the [full open source AlphaFold](https://github.com/deepmind/alphafold/), or the [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).\n",
         "\n",
-        "**This Colab has an small drop in average accuracy for multimers compared to local AlphaFold installation, for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, the AlphaFold-Multimer requires searching for MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times-out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n",
+        "**This Colab has a small drop in average accuracy for multimers compared to local AlphaFold installation, for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, the AlphaFold-Multimer requires searching for MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times-out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n",
         "\n",
         "Please note that this Colab notebook is provided as an early-access prototype and is not a finished product. It is provided for theoretical modelling only and caution should be exercised in its use. \n",
         "\n",
@@ -37,6 +37,17 @@
         "FAQ on how to interpret AlphaFold predictions are [here](https://alphafold.ebi.ac.uk/faq)."
       ]
     },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "uC1dKAwk2eyl"
+      },
+      "source": [
+        "## Setup\n",
+        "\n",
+        "Start by running the 2 cells below to set up AlphaFold and all required software."
+      ]
+    },
     {
       "cell_type": "code",
       "execution_count": null,
@@ -46,7 +57,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Install third-party software\n",
+        "#@title 1. Install third-party software\n",
         "\n",
         "#@markdown Please execute this cell by pressing the _Play_ button \n",
         "#@markdown on the left to download and import third-party software \n",
@@ -114,7 +125,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Download AlphaFold\n",
+        "#@title 2. Download AlphaFold\n",
        "\n",
        "#@markdown Please execute this cell by pressing the *Play* button on \n",
        "#@markdown the left.\n",
@@ -201,7 +212,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Enter the amino acid sequence(s) to fold ⬇️\n",
+        "#@title 3. Enter the amino acid sequence(s) to fold ⬇️\n",
         "#@markdown Enter the amino acid sequence(s) to fold:\n",
         "#@markdown * If you enter only a single sequence, the monomer model will be used.\n",
         "#@markdown * If you enter multiple sequences, the multimer model will be used.\n",
@@ -247,7 +258,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Search against genetic databases\n",
+        "#@title 4. Search against genetic databases\n",
         "\n",
         "#@markdown Once this cell has been executed, you will see\n",
         "#@markdown statistics about the multiple sequence alignment \n",
@@ -275,7 +286,6 @@
         "\n",
         "from alphafold.data import feature_processing\n",
         "from alphafold.data import msa_pairing\n",
-        "from alphafold.data import parsers\n",
         "from alphafold.data import pipeline\n",
         "from alphafold.data import pipeline_multimer\n",
         "from alphafold.data.tools import jackhmmer\n",
@@ -455,7 +465,7 @@
       },
       "outputs": [],
       "source": [
-        "#@title Run AlphaFold and download prediction\n",
+        "#@title 5. Run AlphaFold and download prediction\n",
         "\n",
         "#@markdown Once this cell has been executed, a zip-archive with\n",
         "#@markdown the obtained prediction will be automatically downloaded\n",
@@ -542,7 +552,8 @@
         "        tolerance=2.39,\n",
         "        stiffness=10.0,\n",
         "        exclude_residues=[],\n",
-        "        max_outer_iterations=3)\n",
+        "        max_outer_iterations=3,\n",
+        "        use_gpu=True)\n",
         "    relaxed_pdb, _, _ = amber_relaxer.process(prot=unrelaxed_proteins[best_model_name])\n",
         "  else:\n",
         "    print('Warning: Running without the relaxation stage.')\n",
@@ -694,7 +705,7 @@
         "* How do I get a predicted protein structure for my protein?\n",
         "  * Click on the _Connect_ button on the top right to get started.\n",
         "  * Paste the amino acid sequence of your protein (without any headers) into the “Enter the amino acid sequence to fold”.\n",
-        "  * Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ \u003e _Run all._\n",
+        "  * Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ \u003e _Run all._ Make sure you run all 5 cells in order.\n",
         "  * The predicted protein structure will be downloaded once all cells have been executed. Note: This can take minutes to hours - see below.\n",
         "* How long will this take?\n",
         "  * Downloading the AlphaFold source code can take up to a few minutes.\n",
...
run_alphafold.py
@@ -34,11 +34,11 @@ from alphafold.data import templates
 from alphafold.data.tools import hhsearch
 from alphafold.data.tools import hmmsearch
 from alphafold.model import config
+from alphafold.model import data
 from alphafold.model import model
 from alphafold.relax import relax
 import numpy as np
-from alphafold.model import data
 # Internal import (7716).

 logging.set_verbosity(logging.INFO)
@@ -114,8 +114,21 @@ flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                      'deterministic, because processes like GPU inference are '
                      'nondeterministic.')
 flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
-                     'have been written to disk. WARNING: This will not check '
-                     'if the sequence, database or configuration have changed.')
+                     'have been written to disk instead of running the MSA '
+                     'tools. The MSA files are looked up in the output '
+                     'directory, so it must stay the same between multiple '
+                     'runs that are to reuse the MSAs. WARNING: This will not '
+                     'check if the sequence, database or configuration have '
+                     'changed.')
+flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
+                     'step on the predicted models. Turning relax off might '
+                     'result in predictions with distracting stereochemical '
+                     'violations but might help in case you are having issues '
+                     'with the relaxation stage.')
+flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
+                     'Relax on GPU can be much faster than CPU, so it is '
+                     'recommended to enable if possible. GPUs must be available'
+                     ' if this setting is enabled.')
 FLAGS = flags.FLAGS
@@ -384,12 +397,16 @@ def main(argv):
   logging.info('Have %d models: %s', len(model_runners),
                list(model_runners.keys()))

-  amber_relaxer = relax.AmberRelaxation(
-      max_iterations=RELAX_MAX_ITERATIONS,
-      tolerance=RELAX_ENERGY_TOLERANCE,
-      stiffness=RELAX_STIFFNESS,
-      exclude_residues=RELAX_EXCLUDE_RESIDUES,
-      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS)
+  if FLAGS.run_relax:
+    amber_relaxer = relax.AmberRelaxation(
+        max_iterations=RELAX_MAX_ITERATIONS,
+        tolerance=RELAX_ENERGY_TOLERANCE,
+        stiffness=RELAX_STIFFNESS,
+        exclude_residues=RELAX_EXCLUDE_RESIDUES,
+        max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
+        use_gpu=FLAGS.use_gpu_relax)
+  else:
+    amber_relaxer = None

   random_seed = FLAGS.random_seed
   if random_seed is None:
@@ -422,6 +439,7 @@ if __name__ == '__main__':
       'template_mmcif_dir',
       'max_template_date',
       'obsolete_pdbs_path',
+      'use_gpu_relax',
   ])

   app.run(main)
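Downstream, a `None` relaxer simply disables the relaxation stage for every predicted model. A hedged sketch of the gating pattern (a hypothetical helper, simplified from how the prediction loop consumes `amber_relaxer`; not the actual function body):

```python
from typing import Optional

from alphafold.common import protein
from alphafold.relax import relax


def maybe_relax(
    prot: protein.Protein,
    amber_relaxer: Optional[relax.AmberRelaxation]) -> Optional[str]:
  """Returns a relaxed PDB string, or None when relaxation is disabled."""
  if amber_relaxer is None:  # --run_relax=false
    return None
  relaxed_pdb_str, _, _ = amber_relaxer.process(prot=prot)
  return relaxed_pdb_str
```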