run_alphafold.py 17.7 KB
Newer Older
Augustin-Zidek's avatar
Augustin-Zidek committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Full AlphaFold protein structure prediction script."""
import json
import os
import pathlib
import pickle
import random
21
import shutil
Augustin-Zidek's avatar
Augustin-Zidek committed
22
23
import sys
import time
24
from typing import Dict, Union, Optional
Augustin-Zidek's avatar
Augustin-Zidek committed
25
26
27
28
29

from absl import app
from absl import flags
from absl import logging
from alphafold.common import protein
30
from alphafold.common import residue_constants
Augustin-Zidek's avatar
Augustin-Zidek committed
31
from alphafold.data import pipeline
32
from alphafold.data import pipeline_multimer
Augustin-Zidek's avatar
Augustin-Zidek committed
33
from alphafold.data import templates
34
35
from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
Augustin-Zidek's avatar
Augustin-Zidek committed
36
from alphafold.model import config
Augustin Zidek's avatar
Augustin Zidek committed
37
from alphafold.model import data
Augustin-Zidek's avatar
Augustin-Zidek committed
38
39
from alphafold.model import model
from alphafold.relax import relax
Tom Ward's avatar
Tom Ward committed
40
import numpy as np
41

Augustin-Zidek's avatar
Augustin-Zidek committed
42
43
# Internal import (7716).

44
45
# Set the absl logging verbosity to INFO for the whole script.
logging.set_verbosity(logging.INFO)

46
47
48
49
50
51
# Command-line flag definitions. Binary-path flags default to whatever
# `shutil.which` finds on PATH (None when the tool is not installed); `main`
# rejects a missing binary before any work starts.

# Inputs and outputs.
flags.DEFINE_list(
    'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
    'target that will be folded one after another. If a FASTA file contains '
    'multiple sequences, then it will be folded as a multimer. Paths should be '
    'separated by commas. All FASTA paths must have a unique basename as the '
    'basename is used to name the output directories for each prediction.')

flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
                    'store the results.')

# External tool binaries.
flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
                    'Path to the JackHMMER executable.')
flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
                    'Path to the HHblits executable.')
flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
                    'Path to the HHsearch executable.')
flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
                    'Path to the hmmsearch executable.')
flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
                    'Path to the hmmbuild executable.')
flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
                    'Path to the Kalign executable.')

# Sequence and template databases. Which of these must be set depends on
# --db_preset and --model_preset; `main` enforces that via `_check_flag`.
flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
                    'database for use by JackHMMER.')
flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify '
                    'database for use by JackHMMER.')
flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD '
                    'database for use by HHblits.')
flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
                    'version of BFD used with the "reduced_dbs" preset.')
flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 '
                    'database for use by HHblits.')
flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
                    'database for use by JackHMMer.')
flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 '
                    'database for use by HHsearch.')
flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB '
                    'seqres database for use by hmmsearch.')
flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with '
                    'template mmCIF structures, each named <pdb_id>.cif')
flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
                    'to consider. Important if folding historical test sets.')
flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
                    'mapping from obsolete PDB IDs to the PDB IDs of their '
                    'replacements.')

# Presets.
flags.DEFINE_enum('db_preset', 'full_dbs',
                  ['full_dbs', 'reduced_dbs'],
                  'Choose preset MSA database configuration - '
                  'smaller genetic database config (reduced_dbs) or '
                  'full genetic database config  (full_dbs)')
flags.DEFINE_enum('model_preset', 'monomer',
                  ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
                  'Choose preset model configuration - the monomer model, '
                  'the monomer model with extra ensembling, monomer model with '
                  'pTM head, or multimer model')

# Run-time behaviour.
flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
                     'to obtain a timing that excludes the compilation time, '
                     'which should be more indicative of the time required for '
                     'inferencing many proteins.')
flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                     'pipeline. By default, this is randomly generated. Note '
                     'that even if this is set, Alphafold may still not be '
                     'deterministic, because processes like GPU inference are '
                     'nondeterministic.')
flags.DEFINE_integer('num_multimer_predictions_per_model', 5, 'How many '
                     'predictions (each with a different random seed) will be '
                     'generated per model. E.g. if this is 2 and there are 5 '
                     'models then there will be 10 predictions per input. '
                     'Note: this FLAG only applies if model_preset=multimer')
flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
                     'have been written to disk instead of running the MSA '
                     'tools. The MSA files are looked up in the output '
                     'directory, so it must stay the same between multiple '
                     'runs that are to reuse the MSAs. WARNING: This will not '
                     'check if the sequence, database or configuration have '
                     'changed.')
flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
                     'step on the predicted models. Turning relax off might '
                     'result in predictions with distracting stereochemical '
                     'violations but might help in case you are having issues '
                     'with the relaxation stage.')
flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
                     'Relax on GPU can be much faster than CPU, so it is '
                     'recommended to enable if possible. GPUs must be available'
                     ' if this setting is enabled.')
131

Augustin-Zidek's avatar
Augustin-Zidek committed
132
133
134
135
136
137
138
# Shorthand for the parsed absl flag values.
FLAGS = flags.FLAGS

# Passed as `max_hits` to the template featurizers built in `main`.
MAX_TEMPLATE_HITS = 20

# Settings forwarded to `relax.AmberRelaxation` in `main`.
# NOTE(review): presumably 0 means "no iteration cap" for the minimizer --
# confirm against relax.AmberRelaxation.
RELAX_MAX_ITERATIONS = 0
# NOTE(review): units of the tolerance and stiffness are not visible here --
# confirm against relax.AmberRelaxation before changing them.
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 3
Augustin-Zidek's avatar
Augustin-Zidek committed
140
141


142
143
144
def _check_flag(flag_name: str,
                other_flag_name: str,
                should_be_set: bool) -> None:
  """Validates that a flag is set (or unset) as another flag requires.

  Args:
    flag_name: Name of the flag whose value is checked.
    other_flag_name: Name of the flag that imposes the requirement. Only used
      to build the error message.
    should_be_set: Whether `flag_name` is expected to hold a truthy value.

  Raises:
    ValueError: If the truthiness of the value of `flag_name` does not match
      `should_be_set`.
  """
  if should_be_set != bool(FLAGS[flag_name].value):
    verb = 'be' if should_be_set else 'not be'
    raise ValueError(f'{flag_name} must {verb} set when running with '
                     f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
149
150


Augustin-Zidek's avatar
Augustin-Zidek committed
151
152
153
154
def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: Optional[relax.AmberRelaxation],
    benchmark: bool,
    random_seed: int) -> None:
  """Predicts structure using AlphaFold for the given sequence.

  Runs the data pipeline once, then every model runner in turn, optionally
  relaxes each prediction, and finally ranks the predictions by their
  `ranking_confidence`. Everything is written under
  `<output_dir_base>/<fasta_name>/`:

    * `msas/`: directory handed to the data pipeline as `msa_output_dir`.
    * `features.pkl`: pickled feature dictionary from the data pipeline.
    * `result_<model_name>.pkl`: pickled raw prediction result per model.
    * `unrelaxed_<model_name>.pdb`: unrelaxed structure per model.
    * `relaxed_<model_name>.pdb`: relaxed structure per model (only when
      `amber_relaxer` is given).
    * `ranked_<idx>.pdb`: structures in descending order of confidence;
      relaxed when `amber_relaxer` is given, unrelaxed otherwise.
    * `ranking_debug.json`: confidences per model and the ranked order.
    * `timings.json`: wall-clock seconds for each stage.

  Args:
    fasta_path: Path to the input FASTA file.
    fasta_name: Name used for logging and for the output sub-directory.
    output_dir_base: Directory under which the per-target directory is made.
    data_pipeline: Pipeline whose `process` method produces the features.
    model_runners: Mapping from model name to the runner used to predict.
    amber_relaxer: Relaxer applied to each prediction, or None to skip
      relaxation.
    benchmark: If True, runs each model's `predict` a second time and records
      that timing separately.
    random_seed: Base seed. Model number `i` is run with seed
      `i + random_seed * len(model_runners)`.
  """
  logging.info('Predicting %s', fasta_name)
  timings = {}
  output_dir = os.path.join(output_dir_base, fasta_name)
  msa_output_dir = os.path.join(output_dir, 'msas')
  # exist_ok avoids the check-then-create race when several runs share the
  # same output directory.
  os.makedirs(output_dir, exist_ok=True)
  os.makedirs(msa_output_dir, exist_ok=True)

  # Get features.
  t_0 = time.time()
  feature_dict = data_pipeline.process(
      input_fasta_path=fasta_path,
      msa_output_dir=msa_output_dir)
  timings['features'] = time.time() - t_0

  # Write out features as a pickled dictionary.
  features_output_path = os.path.join(output_dir, 'features.pkl')
  with open(features_output_path, 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)

  unrelaxed_pdbs = {}
  relaxed_pdbs = {}
  ranking_confidences = {}
  # Key under which the confidences are stored in ranking_debug.json. It is
  # refreshed from every prediction result, so the last model decides it,
  # and it stays defined even when `model_runners` is empty.
  ranking_label = 'plddts'

  # Run the models.
  num_models = len(model_runners)
  for model_index, (model_name, model_runner) in enumerate(
      model_runners.items()):
    logging.info('Running model %s on %s', model_name, fasta_name)
    t_0 = time.time()
    # Gives every model a distinct seed derived from the base seed.
    model_random_seed = model_index + random_seed * num_models
    processed_feature_dict = model_runner.process_features(
        feature_dict, random_seed=model_random_seed)
    timings[f'process_features_{model_name}'] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(processed_feature_dict,
                                             random_seed=model_random_seed)
    t_diff = time.time() - t_0
    timings[f'predict_and_compile_{model_name}'] = t_diff
    logging.info(
        'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs',
        model_name, fasta_name, t_diff)

    if benchmark:
      # Second call on the same inputs; logged as the time that excludes
      # compilation.
      t_0 = time.time()
      model_runner.predict(processed_feature_dict,
                           random_seed=model_random_seed)
      t_diff = time.time() - t_0
      timings[f'predict_benchmark_{model_name}'] = t_diff
      logging.info(
          'Total JAX model %s on %s predict time (excludes compilation time): %.1fs',
          model_name, fasta_name, t_diff)

    plddt = prediction_result['plddt']
    ranking_confidences[model_name] = prediction_result['ranking_confidence']
    ranking_label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts'

    # Save the model outputs.
    result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
    with open(result_output_path, 'wb') as f:
      pickle.dump(prediction_result, f, protocol=4)

    # Add the predicted LDDT in the b-factor column.
    # Note that higher predicted LDDT value means higher model confidence.
    plddt_b_factors = np.repeat(
        plddt[:, None], residue_constants.atom_type_num, axis=-1)
    unrelaxed_protein = protein.from_prediction(
        features=processed_feature_dict,
        result=prediction_result,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=not model_runner.multimer_mode)

    unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
    with open(unrelaxed_pdb_path, 'w') as f:
      f.write(unrelaxed_pdbs[model_name])

    if amber_relaxer:
      # Relax the prediction.
      t_0 = time.time()
      relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
      timings[f'relax_{model_name}'] = time.time() - t_0

      relaxed_pdbs[model_name] = relaxed_pdb_str

      # Save the relaxed PDB.
      relaxed_output_path = os.path.join(
          output_dir, f'relaxed_{model_name}.pdb')
      with open(relaxed_output_path, 'w') as f:
        f.write(relaxed_pdb_str)

  # Rank by model confidence and write out relaxed PDBs in rank order.
  ranked_order = []
  for idx, (model_name, _) in enumerate(
      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
    ranked_order.append(model_name)
    ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
    with open(ranked_output_path, 'w') as f:
      if amber_relaxer:
        f.write(relaxed_pdbs[model_name])
      else:
        f.write(unrelaxed_pdbs[model_name])

  ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
  with open(ranking_output_path, 'w') as f:
    f.write(json.dumps(
        {ranking_label: ranking_confidences, 'order': ranked_order}, indent=4))

  logging.info('Final timings for %s: %s', fasta_name, timings)

  timings_output_path = os.path.join(output_dir, 'timings.json')
  with open(timings_output_path, 'w') as f:
    f.write(json.dumps(timings, indent=4))


def main(argv):
  """Validates the flags, builds the pipelines and models, and folds each target.

  Args:
    argv: Positional command-line arguments left over after flag parsing; only
      the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments were given.
    ValueError: If a required tool binary was not found, if a database flag
      is set or unset inconsistently with `--db_preset` / `--model_preset`,
      if two FASTA paths share a basename, or if the configuration yields no
      models to run.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  for tool_name in (
      'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'):
    if not FLAGS[f'{tool_name}_binary_path'].value:
      raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                       'sure it is installed on your system.')

  # The reduced preset uses small BFD instead of BFD and Uniclust30.
  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
  _check_flag('small_bfd_database_path', 'db_preset',
              should_be_set=use_small_bfd)
  _check_flag('bfd_database_path', 'db_preset',
              should_be_set=not use_small_bfd)
  _check_flag('uniclust30_database_path', 'db_preset',
              should_be_set=not use_small_bfd)

  # Multimer searches templates in PDB seqres and also needs Uniprot; the
  # monomer presets use PDB70 instead.
  run_multimer_system = 'multimer' in FLAGS.model_preset
  _check_flag('pdb70_database_path', 'model_preset',
              should_be_set=not run_multimer_system)
  _check_flag('pdb_seqres_database_path', 'model_preset',
              should_be_set=run_multimer_system)
  _check_flag('uniprot_database_path', 'model_preset',
              should_be_set=run_multimer_system)

  if FLAGS.model_preset == 'monomer_casp14':
    num_ensemble = 8
  else:
    num_ensemble = 1

  # Check for duplicate FASTA file names.
  fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
  if len(fasta_names) != len(set(fasta_names)):
    raise ValueError('All FASTA paths must have a unique basename.')

  if run_multimer_system:
    template_searcher = hmmsearch.Hmmsearch(
        binary_path=FLAGS.hmmsearch_binary_path,
        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
        database_path=FLAGS.pdb_seqres_database_path)
    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
  else:
    template_searcher = hhsearch.HHSearch(
        binary_path=FLAGS.hhsearch_binary_path,
        databases=[FLAGS.pdb70_database_path])
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)

  monomer_data_pipeline = pipeline.DataPipeline(
      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
      hhblits_binary_path=FLAGS.hhblits_binary_path,
      uniref90_database_path=FLAGS.uniref90_database_path,
      mgnify_database_path=FLAGS.mgnify_database_path,
      bfd_database_path=FLAGS.bfd_database_path,
      uniclust30_database_path=FLAGS.uniclust30_database_path,
      small_bfd_database_path=FLAGS.small_bfd_database_path,
      template_searcher=template_searcher,
      template_featurizer=template_featurizer,
      use_small_bfd=use_small_bfd,
      use_precomputed_msas=FLAGS.use_precomputed_msas)

  if run_multimer_system:
    num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
    # The multimer pipeline wraps the monomer pipeline.
    data_pipeline = pipeline_multimer.DataPipeline(
        monomer_data_pipeline=monomer_data_pipeline,
        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
        uniprot_database_path=FLAGS.uniprot_database_path,
        use_precomputed_msas=FLAGS.use_precomputed_msas)
  else:
    num_predictions_per_model = 1
    data_pipeline = monomer_data_pipeline

  model_runners = {}
  model_names = config.MODEL_PRESETS[FLAGS.model_preset]
  for model_name in model_names:
    model_config = config.model_config(model_name)
    if run_multimer_system:
      model_config.model.num_ensemble_eval = num_ensemble
    else:
      model_config.data.eval.num_ensemble = num_ensemble
    model_params = data.get_model_haiku_params(
        model_name=model_name, data_dir=FLAGS.data_dir)
    model_runner = model.RunModel(model_config, model_params)
    # The same runner is registered once per requested prediction; each
    # entry later gets its own seed in `predict_structure`.
    for i in range(num_predictions_per_model):
      model_runners[f'{model_name}_pred_{i}'] = model_runner

  logging.info('Have %d models: %s', len(model_runners),
               list(model_runners.keys()))

  # Fail early and clearly: with no models the seed generation below would
  # divide by zero, and with an explicit seed the failure would only surface
  # after the data pipeline had already run.
  if not model_runners:
    raise ValueError(
        'No models to run; check "--model_preset" and make sure '
        '"--num_multimer_predictions_per_model" is at least 1.')

  if FLAGS.run_relax:
    amber_relaxer = relax.AmberRelaxation(
        max_iterations=RELAX_MAX_ITERATIONS,
        tolerance=RELAX_ENERGY_TOLERANCE,
        stiffness=RELAX_STIFFNESS,
        exclude_residues=RELAX_EXCLUDE_RESIDUES,
        max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
        use_gpu=FLAGS.use_gpu_relax)
  else:
    amber_relaxer = None

  random_seed = FLAGS.random_seed
  if random_seed is None:
    # Bounded so that the per-model seed computed in `predict_structure`
    # (index + seed * number of models) stays below sys.maxsize.
    random_seed = random.randrange(sys.maxsize // len(model_runners))
  logging.info('Using random seed %d for the data pipeline', random_seed)

  # Predict structure for each of the sequences.
  for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names):
    predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=FLAGS.output_dir,
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        amber_relaxer=amber_relaxer,
        benchmark=FLAGS.benchmark,
        random_seed=random_seed)
Augustin-Zidek's avatar
Augustin-Zidek committed
407
408
409
410
411
412
413
414
415
416
417
418


if __name__ == '__main__':
  # These flags have no usable default, so flag parsing fails when any of
  # them is missing from the command line.
  flags.mark_flags_as_required([
      'fasta_paths',
      'output_dir',
      'data_dir',
      'uniref90_database_path',
      'mgnify_database_path',
      'template_mmcif_dir',
      'max_template_date',
      'obsolete_pdbs_path',
      'use_gpu_relax',
  ])

  app.run(main)