Commit 0be2b30b authored by Augustin-Zidek's avatar Augustin-Zidek
Browse files

Add code for AlphaFold-Multimer.

PiperOrigin-RevId: 407076987
parent 1d43aaff
...@@ -8,5 +8,6 @@ immutabledict==2.0.0 ...@@ -8,5 +8,6 @@ immutabledict==2.0.0
jax==0.2.14 jax==0.2.14
ml-collections==0.1.0 ml-collections==0.1.0
numpy==1.19.5 numpy==1.19.5
pandas==1.3.4
scipy==1.7.0 scipy==1.7.0
tensorflow-cpu==2.5.0 tensorflow-cpu==2.5.0
...@@ -18,9 +18,10 @@ import os ...@@ -18,9 +18,10 @@ import os
import pathlib import pathlib
import pickle import pickle
import random import random
import shutil
import sys import sys
import time import time
from typing import Dict from typing import Dict, Union, Optional
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -28,30 +29,47 @@ from absl import logging ...@@ -28,30 +29,47 @@ from absl import logging
from alphafold.common import protein from alphafold.common import protein
from alphafold.common import residue_constants from alphafold.common import residue_constants
from alphafold.data import pipeline from alphafold.data import pipeline
from alphafold.data import pipeline_multimer
from alphafold.data import templates from alphafold.data import templates
from alphafold.model import data from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
from alphafold.model import config from alphafold.model import config
from alphafold.model import model from alphafold.model import model
from alphafold.relax import relax from alphafold.relax import relax
import numpy as np import numpy as np
from alphafold.model import data
# Internal import (7716). # Internal import (7716).
logging.set_verbosity(logging.INFO)
flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing '
'one sequence. Paths should be separated by commas. ' 'a prediction target. Paths should be separated by commas. '
'All FASTA paths must have a unique basename as the ' 'All FASTA paths must have a unique basename as the '
'basename is used to name the output directories for ' 'basename is used to name the output directories for '
'each prediction.') 'each prediction.')
flags.DEFINE_list('is_prokaryote_list', None, 'Optional for multimer system, '
'not used by the single chain system. '
'This list should contain a boolean for each fasta '
'specifying true where the target complex is from a '
'prokaryote, and false where it is not, or where the '
'origin is unknown. These values determine the pairing '
'method for the MSA.')
flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
'store the results.') 'store the results.')
flags.DEFINE_list('model_names', None, 'Names of models to use.') flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('jackhmmer_binary_path', '/usr/bin/jackhmmer',
'Path to the JackHMMER executable.') 'Path to the JackHMMER executable.')
flags.DEFINE_string('hhblits_binary_path', '/usr/bin/hhblits', flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
'Path to the HHblits executable.') 'Path to the HHblits executable.')
flags.DEFINE_string('hhsearch_binary_path', '/usr/bin/hhsearch', flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
'Path to the HHsearch executable.') 'Path to the HHsearch executable.')
flags.DEFINE_string('kalign_binary_path', '/usr/bin/kalign', flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
'Path to the hmmsearch executable.')
flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
'Path to the hmmbuild executable.')
flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
'Path to the Kalign executable.') 'Path to the Kalign executable.')
flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 ' flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
'database for use by JackHMMER.') 'database for use by JackHMMER.')
...@@ -63,8 +81,12 @@ flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small ' ...@@ -63,8 +81,12 @@ flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
'version of BFD used with the "reduced_dbs" preset.') 'version of BFD used with the "reduced_dbs" preset.')
flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 ' flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 '
'database for use by HHblits.') 'database for use by HHblits.')
flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
'database for use by JackHMMer.')
flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 ' flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 '
'database for use by HHsearch.') 'database for use by HHsearch.')
flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB '
'seqres database for use by hmmsearch.')
flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with ' flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with '
'template mmCIF structures, each named <pdb_id>.cif') 'template mmCIF structures, each named <pdb_id>.cif')
flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
...@@ -72,13 +94,16 @@ flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' ...@@ -72,13 +94,16 @@ flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a ' flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
'mapping from obsolete PDB IDs to the PDB IDs of their ' 'mapping from obsolete PDB IDs to the PDB IDs of their '
'replacements.') 'replacements.')
flags.DEFINE_enum('preset', 'full_dbs', flags.DEFINE_enum('db_preset', 'full_dbs',
['reduced_dbs', 'full_dbs', 'casp14'], ['full_dbs', 'reduced_dbs'],
'Choose preset model configuration - no ensembling and ' 'Choose preset MSA database configuration - '
'smaller genetic database config (reduced_dbs), no ' 'smaller genetic database config (reduced_dbs) or '
'ensembling and full genetic database config (full_dbs) or ' 'full genetic database config (full_dbs)')
'full genetic database config and 8 model ensemblings ' flags.DEFINE_enum('model_preset', 'monomer',
'(casp14).') ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
'Choose preset model configuration - the monomer model, '
'the monomer model with extra ensembling, monomer model with '
'pTM head, or multimer model')
flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
'to obtain a timing that excludes the compilation time, ' 'to obtain a timing that excludes the compilation time, '
'which should be more indicative of the time required for ' 'which should be more indicative of the time required for '
...@@ -88,6 +113,10 @@ flags.DEFINE_integer('random_seed', None, 'The random seed for the data ' ...@@ -88,6 +113,10 @@ flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
'that even if this is set, Alphafold may still not be ' 'that even if this is set, Alphafold may still not be '
'deterministic, because processes like GPU inference are ' 'deterministic, because processes like GPU inference are '
'nondeterministic.') 'nondeterministic.')
flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
'have been written to disk. WARNING: This will not check '
'if the sequence, database or configuration have changed.')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
MAX_TEMPLATE_HITS = 20 MAX_TEMPLATE_HITS = 20
...@@ -95,25 +124,30 @@ RELAX_MAX_ITERATIONS = 0 ...@@ -95,25 +124,30 @@ RELAX_MAX_ITERATIONS = 0
RELAX_ENERGY_TOLERANCE = 2.39 RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0 RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = [] RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 20 RELAX_MAX_OUTER_ITERATIONS = 3
def _check_flag(flag_name: str, preset: str, should_be_set: bool): def _check_flag(flag_name: str,
other_flag_name: str,
should_be_set: bool):
if should_be_set != bool(FLAGS[flag_name].value): if should_be_set != bool(FLAGS[flag_name].value):
verb = 'be' if should_be_set else 'not be' verb = 'be' if should_be_set else 'not be'
raise ValueError(f'{flag_name} must {verb} set for preset "{preset}"') raise ValueError(f'{flag_name} must {verb} set when running with '
f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
def predict_structure( def predict_structure(
fasta_path: str, fasta_path: str,
fasta_name: str, fasta_name: str,
output_dir_base: str, output_dir_base: str,
data_pipeline: pipeline.DataPipeline, data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
model_runners: Dict[str, model.RunModel], model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation, amber_relaxer: relax.AmberRelaxation,
benchmark: bool, benchmark: bool,
random_seed: int): random_seed: int,
is_prokaryote: Optional[bool] = None):
"""Predicts structure using AlphaFold for the given sequence.""" """Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {} timings = {}
output_dir = os.path.join(output_dir_base, fasta_name) output_dir = os.path.join(output_dir_base, fasta_name)
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
...@@ -124,9 +158,15 @@ def predict_structure( ...@@ -124,9 +158,15 @@ def predict_structure(
# Get features. # Get features.
t_0 = time.time() t_0 = time.time()
feature_dict = data_pipeline.process( if is_prokaryote is None:
input_fasta_path=fasta_path, feature_dict = data_pipeline.process(
msa_output_dir=msa_output_dir) input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir)
else:
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path,
msa_output_dir=msa_output_dir,
is_prokaryote=is_prokaryote)
timings['features'] = time.time() - t_0 timings['features'] = time.time() - t_0
# Write out features as a pickled dictionary. # Write out features as a pickled dictionary.
...@@ -134,33 +174,42 @@ def predict_structure( ...@@ -134,33 +174,42 @@ def predict_structure(
with open(features_output_path, 'wb') as f: with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4) pickle.dump(feature_dict, f, protocol=4)
unrelaxed_pdbs = {}
relaxed_pdbs = {} relaxed_pdbs = {}
plddts = {} ranking_confidences = {}
# Run the models. # Run the models.
for model_name, model_runner in model_runners.items(): num_models = len(model_runners)
logging.info('Running model %s', model_name) for model_index, (model_name, model_runner) in enumerate(
model_runners.items()):
logging.info('Running model %s on %s', model_name, fasta_name)
t_0 = time.time() t_0 = time.time()
model_random_seed = model_index + random_seed * num_models
processed_feature_dict = model_runner.process_features( processed_feature_dict = model_runner.process_features(
feature_dict, random_seed=random_seed) feature_dict, random_seed=model_random_seed)
timings[f'process_features_{model_name}'] = time.time() - t_0 timings[f'process_features_{model_name}'] = time.time() - t_0
t_0 = time.time() t_0 = time.time()
prediction_result = model_runner.predict(processed_feature_dict) prediction_result = model_runner.predict(processed_feature_dict,
random_seed=model_random_seed)
t_diff = time.time() - t_0 t_diff = time.time() - t_0
timings[f'predict_and_compile_{model_name}'] = t_diff timings[f'predict_and_compile_{model_name}'] = t_diff
logging.info( logging.info(
'Total JAX model %s predict time (includes compilation time, see --benchmark): %.0f?', 'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs',
model_name, t_diff) model_name, fasta_name, t_diff)
if benchmark: if benchmark:
t_0 = time.time() t_0 = time.time()
model_runner.predict(processed_feature_dict) model_runner.predict(processed_feature_dict,
timings[f'predict_benchmark_{model_name}'] = time.time() - t_0 random_seed=model_random_seed)
t_diff = time.time() - t_0
timings[f'predict_benchmark_{model_name}'] = t_diff
logging.info(
'Total JAX model %s on %s predict time (excludes compilation time): %.1fs',
model_name, fasta_name, t_diff)
# Get mean pLDDT confidence metric.
plddt = prediction_result['plddt'] plddt = prediction_result['plddt']
plddts[model_name] = np.mean(plddt) ranking_confidences[model_name] = prediction_result['ranking_confidence']
# Save the model outputs. # Save the model outputs.
result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
...@@ -174,36 +223,45 @@ def predict_structure( ...@@ -174,36 +223,45 @@ def predict_structure(
unrelaxed_protein = protein.from_prediction( unrelaxed_protein = protein.from_prediction(
features=processed_feature_dict, features=processed_feature_dict,
result=prediction_result, result=prediction_result,
b_factors=plddt_b_factors) b_factors=plddt_b_factors,
remove_leading_feature_dimension=not model_runner.multimer_mode)
unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
with open(unrelaxed_pdb_path, 'w') as f: with open(unrelaxed_pdb_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(unrelaxed_pdbs[model_name])
# Relax the prediction. if amber_relaxer:
t_0 = time.time() # Relax the prediction.
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) t_0 = time.time()
timings[f'relax_{model_name}'] = time.time() - t_0 relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
timings[f'relax_{model_name}'] = time.time() - t_0
relaxed_pdbs[model_name] = relaxed_pdb_str relaxed_pdbs[model_name] = relaxed_pdb_str
# Save the relaxed PDB. # Save the relaxed PDB.
relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name}.pdb') relaxed_output_path = os.path.join(
with open(relaxed_output_path, 'w') as f: output_dir, f'relaxed_{model_name}.pdb')
f.write(relaxed_pdb_str) with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# Rank by pLDDT and write out relaxed PDBs in rank order. # Rank by model confidence and write out relaxed PDBs in rank order.
ranked_order = [] ranked_order = []
for idx, (model_name, _) in enumerate( for idx, (model_name, _) in enumerate(
sorted(plddts.items(), key=lambda x: x[1], reverse=True)): sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
ranked_order.append(model_name) ranked_order.append(model_name)
ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
with open(ranked_output_path, 'w') as f: with open(ranked_output_path, 'w') as f:
f.write(relaxed_pdbs[model_name]) if amber_relaxer:
f.write(relaxed_pdbs[model_name])
else:
f.write(unrelaxed_pdbs[model_name])
ranking_output_path = os.path.join(output_dir, 'ranking_debug.json') ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
with open(ranking_output_path, 'w') as f: with open(ranking_output_path, 'w') as f:
f.write(json.dumps({'plddts': plddts, 'order': ranked_order}, indent=4)) label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts'
f.write(json.dumps(
{label: ranking_confidences, 'order': ranked_order}, indent=4))
logging.info('Final timings for %s: %s', fasta_name, timings) logging.info('Final timings for %s: %s', fasta_name, timings)
...@@ -216,49 +274,108 @@ def main(argv): ...@@ -216,49 +274,108 @@ def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.') raise app.UsageError('Too many command-line arguments.')
use_small_bfd = FLAGS.preset == 'reduced_dbs' for tool_name in (
_check_flag('small_bfd_database_path', FLAGS.preset, 'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'):
if not FLAGS[f'{tool_name}_binary_path'].value:
raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
'sure it is installed on your system.')
use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
_check_flag('small_bfd_database_path', 'db_preset',
should_be_set=use_small_bfd) should_be_set=use_small_bfd)
_check_flag('bfd_database_path', FLAGS.preset, _check_flag('bfd_database_path', 'db_preset',
should_be_set=not use_small_bfd) should_be_set=not use_small_bfd)
_check_flag('uniclust30_database_path', FLAGS.preset, _check_flag('uniclust30_database_path', 'db_preset',
should_be_set=not use_small_bfd) should_be_set=not use_small_bfd)
if FLAGS.preset in ('reduced_dbs', 'full_dbs'): run_multimer_system = 'multimer' in FLAGS.model_preset
num_ensemble = 1 _check_flag('pdb70_database_path', 'model_preset',
elif FLAGS.preset == 'casp14': should_be_set=not run_multimer_system)
_check_flag('pdb_seqres_database_path', 'model_preset',
should_be_set=run_multimer_system)
_check_flag('uniprot_database_path', 'model_preset',
should_be_set=run_multimer_system)
if FLAGS.model_preset == 'monomer_casp14':
num_ensemble = 8 num_ensemble = 8
else:
num_ensemble = 1
# Check for duplicate FASTA file names. # Check for duplicate FASTA file names.
fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
if len(fasta_names) != len(set(fasta_names)): if len(fasta_names) != len(set(fasta_names)):
raise ValueError('All FASTA paths must have a unique basename.') raise ValueError('All FASTA paths must have a unique basename.')
template_featurizer = templates.TemplateHitFeaturizer( # Check that is_prokaryote_list has same number of elements as fasta_paths,
mmcif_dir=FLAGS.template_mmcif_dir, # and convert to bool.
max_template_date=FLAGS.max_template_date, if FLAGS.is_prokaryote_list:
max_hits=MAX_TEMPLATE_HITS, if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths):
kalign_binary_path=FLAGS.kalign_binary_path, raise ValueError('--is_prokaryote_list must either be omitted or match '
release_dates_path=None, 'length of --fasta_paths.')
obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) is_prokaryote_list = []
for s in FLAGS.is_prokaryote_list:
data_pipeline = pipeline.DataPipeline( if s in ('true', 'false'):
is_prokaryote_list.append(s == 'true')
else:
raise ValueError('--is_prokaryote_list must contain comma separated '
'true or false values.')
else: # Default is_prokaryote to False.
is_prokaryote_list = [False] * len(fasta_names)
if run_multimer_system:
template_searcher = hmmsearch.Hmmsearch(
binary_path=FLAGS.hmmsearch_binary_path,
hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
database_path=FLAGS.pdb_seqres_database_path)
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=FLAGS.template_mmcif_dir,
max_template_date=FLAGS.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
kalign_binary_path=FLAGS.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
else:
template_searcher = hhsearch.HHSearch(
binary_path=FLAGS.hhsearch_binary_path,
databases=[FLAGS.pdb70_database_path])
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=FLAGS.template_mmcif_dir,
max_template_date=FLAGS.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
kalign_binary_path=FLAGS.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
monomer_data_pipeline = pipeline.DataPipeline(
jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
hhblits_binary_path=FLAGS.hhblits_binary_path, hhblits_binary_path=FLAGS.hhblits_binary_path,
hhsearch_binary_path=FLAGS.hhsearch_binary_path,
uniref90_database_path=FLAGS.uniref90_database_path, uniref90_database_path=FLAGS.uniref90_database_path,
mgnify_database_path=FLAGS.mgnify_database_path, mgnify_database_path=FLAGS.mgnify_database_path,
bfd_database_path=FLAGS.bfd_database_path, bfd_database_path=FLAGS.bfd_database_path,
uniclust30_database_path=FLAGS.uniclust30_database_path, uniclust30_database_path=FLAGS.uniclust30_database_path,
small_bfd_database_path=FLAGS.small_bfd_database_path, small_bfd_database_path=FLAGS.small_bfd_database_path,
pdb70_database_path=FLAGS.pdb70_database_path, template_searcher=template_searcher,
template_featurizer=template_featurizer, template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd) use_small_bfd=use_small_bfd,
use_precomputed_msas=FLAGS.use_precomputed_msas)
if run_multimer_system:
data_pipeline = pipeline_multimer.DataPipeline(
monomer_data_pipeline=monomer_data_pipeline,
jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
uniprot_database_path=FLAGS.uniprot_database_path,
use_precomputed_msas=FLAGS.use_precomputed_msas)
else:
data_pipeline = monomer_data_pipeline
model_runners = {} model_runners = {}
for model_name in FLAGS.model_names: model_names = config.MODEL_PRESETS[FLAGS.model_preset]
for model_name in model_names:
model_config = config.model_config(model_name) model_config = config.model_config(model_name)
model_config.data.eval.num_ensemble = num_ensemble if run_multimer_system:
model_config.model.num_ensemble_eval = num_ensemble
else:
model_config.data.eval.num_ensemble = num_ensemble
model_params = data.get_model_haiku_params( model_params = data.get_model_haiku_params(
model_name=model_name, data_dir=FLAGS.data_dir) model_name=model_name, data_dir=FLAGS.data_dir)
model_runner = model.RunModel(model_config, model_params) model_runner = model.RunModel(model_config, model_params)
...@@ -276,11 +393,13 @@ def main(argv): ...@@ -276,11 +393,13 @@ def main(argv):
random_seed = FLAGS.random_seed random_seed = FLAGS.random_seed
if random_seed is None: if random_seed is None:
random_seed = random.randrange(sys.maxsize) random_seed = random.randrange(sys.maxsize // len(model_names))
logging.info('Using random seed %d for the data pipeline', random_seed) logging.info('Using random seed %d for the data pipeline', random_seed)
# Predict structure for each of the sequences. # Predict structure for each of the sequences.
for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names): for i, fasta_path in enumerate(FLAGS.fasta_paths):
is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None
fasta_name = fasta_names[i]
predict_structure( predict_structure(
fasta_path=fasta_path, fasta_path=fasta_path,
fasta_name=fasta_name, fasta_name=fasta_name,
...@@ -289,19 +408,17 @@ def main(argv): ...@@ -289,19 +408,17 @@ def main(argv):
model_runners=model_runners, model_runners=model_runners,
amber_relaxer=amber_relaxer, amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark, benchmark=FLAGS.benchmark,
random_seed=random_seed) random_seed=random_seed,
is_prokaryote=is_prokaryote)
if __name__ == '__main__': if __name__ == '__main__':
flags.mark_flags_as_required([ flags.mark_flags_as_required([
'fasta_paths', 'fasta_paths',
'output_dir', 'output_dir',
'model_names',
'data_dir', 'data_dir',
'preset',
'uniref90_database_path', 'uniref90_database_path',
'mgnify_database_path', 'mgnify_database_path',
'pdb70_database_path',
'template_mmcif_dir', 'template_mmcif_dir',
'max_template_date', 'max_template_date',
'obsolete_pdbs_path', 'obsolete_pdbs_path',
......
...@@ -26,7 +26,11 @@ import numpy as np ...@@ -26,7 +26,11 @@ import numpy as np
class RunAlphafoldTest(parameterized.TestCase): class RunAlphafoldTest(parameterized.TestCase):
def test_end_to_end(self): @parameterized.named_parameters(
('relax', True),
('no_relax', False),
)
def test_end_to_end(self, do_relax):
data_pipeline_mock = mock.Mock() data_pipeline_mock = mock.Mock()
model_runner_mock = mock.Mock() model_runner_mock = mock.Mock()
...@@ -46,11 +50,13 @@ class RunAlphafoldTest(parameterized.TestCase): ...@@ -46,11 +50,13 @@ class RunAlphafoldTest(parameterized.TestCase):
'logits': np.ones((10, 50)), 'logits': np.ones((10, 50)),
}, },
'plddt': np.ones(10) * 42, 'plddt': np.ones(10) * 42,
'ranking_confidence': 90,
'ptm': np.array(0.), 'ptm': np.array(0.),
'aligned_confidence_probs': np.zeros((10, 10, 50)), 'aligned_confidence_probs': np.zeros((10, 10, 50)),
'predicted_aligned_error': np.zeros((10, 10)), 'predicted_aligned_error': np.zeros((10, 10)),
'max_predicted_aligned_error': np.array(0.), 'max_predicted_aligned_error': np.array(0.),
} }
model_runner_mock.multimer_mode = False
amber_relaxer_mock.process.return_value = ('RELAXED', None, None) amber_relaxer_mock.process.return_value = ('RELAXED', None, None)
fasta_path = os.path.join(absltest.get_default_test_tmpdir(), fasta_path = os.path.join(absltest.get_default_test_tmpdir(),
...@@ -67,7 +73,7 @@ class RunAlphafoldTest(parameterized.TestCase): ...@@ -67,7 +73,7 @@ class RunAlphafoldTest(parameterized.TestCase):
output_dir_base=out_dir, output_dir_base=out_dir,
data_pipeline=data_pipeline_mock, data_pipeline=data_pipeline_mock,
model_runners={'model1': model_runner_mock}, model_runners={'model1': model_runner_mock},
amber_relaxer=amber_relaxer_mock, amber_relaxer=amber_relaxer_mock if do_relax else None,
benchmark=False, benchmark=False,
random_seed=0) random_seed=0)
...@@ -76,10 +82,13 @@ class RunAlphafoldTest(parameterized.TestCase): ...@@ -76,10 +82,13 @@ class RunAlphafoldTest(parameterized.TestCase):
self.assertIn('test', base_output_files) self.assertIn('test', base_output_files)
target_output_files = os.listdir(os.path.join(out_dir, 'test')) target_output_files = os.listdir(os.path.join(out_dir, 'test'))
self.assertCountEqual( expected_files = [
['features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', 'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
'relaxed_model1.pdb', 'result_model1.pkl', 'timings.json', 'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
'unrelaxed_model1.pdb'], target_output_files) ]
if do_relax:
expected_files.append('relaxed_model1.pdb')
self.assertCountEqual(expected_files, target_output_files)
# Check that pLDDT is set in the B-factor column. # Check that pLDDT is set in the B-factor column.
with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f: with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f:
......
...@@ -20,17 +20,17 @@ ...@@ -20,17 +20,17 @@
set -e set -e
if [[ $# -eq 0 ]]; then if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument." echo "Error: download directory must be provided as an input argument."
exit 1 exit 1
fi fi
if ! command -v aria2c &> /dev/null ; then if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1 exit 1
fi fi
DOWNLOAD_DIR="$1" DOWNLOAD_DIR="$1"
DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs.
if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
then then
echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
...@@ -42,12 +42,12 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")" ...@@ -42,12 +42,12 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")"
echo "Downloading AlphaFold parameters..." echo "Downloading AlphaFold parameters..."
bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}" bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
if [[ "${DOWNLOAD_MODE}" = full_dbs ]] ; then if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then
echo "Downloading BFD..."
bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
else
echo "Downloading Small BFD..." echo "Downloading Small BFD..."
bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}" bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
else
echo "Downloading BFD..."
bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
fi fi
echo "Downloading MGnify..." echo "Downloading MGnify..."
...@@ -65,4 +65,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}" ...@@ -65,4 +65,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
echo "Downloading Uniref90..." echo "Downloading Uniref90..."
bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}" bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
echo "Downloading UniProt..."
bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}"
echo "Downloading PDB SeqRes..."
bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}"
echo "All data downloaded." echo "All data downloaded."
...@@ -31,7 +31,7 @@ fi ...@@ -31,7 +31,7 @@ fi
DOWNLOAD_DIR="$1" DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/params" ROOT_DIR="${DOWNLOAD_DIR}/params"
SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar" SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar"
BASENAME=$(basename "${SOURCE_URL}") BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}" mkdir --parents "${ROOT_DIR}"
......
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips the PDB SeqRes database for AlphaFold.
#
# Usage: bash download_pdb_seqres.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres"
SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt"
BASENAME=$(basename "${SOURCE_URL}")
mkdir --parents "${ROOT_DIR}"
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads, unzips and merges the SwissProt and TrEMBL databases for
# AlphaFold-Multimer.
#
# Usage: bash download_uniprot.sh /path/to/download/directory
set -e
if [[ $# -eq 0 ]]; then
echo "Error: download directory must be provided as an input argument."
exit 1
fi
if ! command -v aria2c &> /dev/null ; then
echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
exit 1
fi
DOWNLOAD_DIR="$1"
ROOT_DIR="${DOWNLOAD_DIR}/uniprot"
TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz"
TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}")
TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}"
SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz"
SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}")
SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}"
mkdir --parents "${ROOT_DIR}"
aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}"
aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}"
pushd "${ROOT_DIR}"
gunzip "${ROOT_DIR}/${TREMBL_BASENAME}"
gunzip "${ROOT_DIR}/${SPROT_BASENAME}"
# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up.
cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}"
mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta"
rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}"
popd
...@@ -18,7 +18,7 @@ from setuptools import setup ...@@ -18,7 +18,7 @@ from setuptools import setup
setup( setup(
name='alphafold', name='alphafold',
version='2.0.0', version='2.1.0',
description='An implementation of the inference pipeline of AlphaFold v2.0.' description='An implementation of the inference pipeline of AlphaFold v2.0.'
'This is a completely new model that was entered as AlphaFold2 in CASP14 ' 'This is a completely new model that was entered as AlphaFold2 in CASP14 '
'and published in Nature.', 'and published in Nature.',
...@@ -38,6 +38,7 @@ setup( ...@@ -38,6 +38,7 @@ setup(
'jax', 'jax',
'ml-collections', 'ml-collections',
'numpy', 'numpy',
'pandas',
'scipy', 'scipy',
'tensorflow-cpu', 'tensorflow-cpu',
], ],
...@@ -49,6 +50,9 @@ setup( ...@@ -49,6 +50,9 @@ setup(
'Operating System :: POSIX :: Linux', 'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
], ],
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment