run_alphafold.py 19.5 KB
Newer Older
Augustin-Zidek's avatar
Augustin-Zidek committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Full AlphaFold protein structure prediction script."""
16
import enum
Augustin-Zidek's avatar
Augustin-Zidek committed
17
18
19
20
21
import json
import os
import pathlib
import pickle
import random
22
import shutil
Augustin-Zidek's avatar
Augustin-Zidek committed
23
24
import sys
import time
25
from typing import Any, Dict, Mapping, Union
Augustin-Zidek's avatar
Augustin-Zidek committed
26
27
28
29
30

from absl import app
from absl import flags
from absl import logging
from alphafold.common import protein
31
from alphafold.common import residue_constants
Augustin-Zidek's avatar
Augustin-Zidek committed
32
from alphafold.data import pipeline
33
from alphafold.data import pipeline_multimer
Augustin-Zidek's avatar
Augustin-Zidek committed
34
from alphafold.data import templates
35
36
from alphafold.data.tools import hhsearch
from alphafold.data.tools import hmmsearch
Augustin-Zidek's avatar
Augustin-Zidek committed
37
from alphafold.model import config
Augustin Zidek's avatar
Augustin Zidek committed
38
from alphafold.model import data
Augustin-Zidek's avatar
Augustin-Zidek committed
39
40
from alphafold.model import model
from alphafold.relax import relax
41
import jax.numpy as jnp
Tom Ward's avatar
Tom Ward committed
42
import numpy as np
43

Augustin-Zidek's avatar
Augustin-Zidek committed
44
45
# Internal import (7716).

46
47
# Show INFO-level absl log messages (progress and timing reports) by default.
logging.set_verbosity(logging.INFO)


@enum.unique
class ModelsToRelax(enum.Enum):
  """Selects which predictions the final Amber relaxation stage is run on.

  ALL relaxes every prediction, BEST relaxes only the top-ranked prediction,
  and NONE skips relaxation (the unrelaxed structures are written instead).
  """
  ALL = 0
  BEST = 1
  NONE = 2


# Inputs, outputs and model selection.
flags.DEFINE_list(
    'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
    'target that will be folded one after another. If a FASTA file contains '
    'multiple sequences, then it will be folded as a multimer. Paths should be '
    'separated by commas. All FASTA paths must have a unique basename as the '
    'basename is used to name the output directories for each prediction.')

flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.')
flags.DEFINE_list('model_names', None, 'Names of models to use.')
flags.DEFINE_string('output_dir', None, 'Path to a directory that will '
                    'store the results.')

# External tool binaries. Defaults are whatever shutil.which finds on $PATH
# (None when not found; main() rejects missing binaries).
flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'),
                    'Path to the JackHMMER executable.')
flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'),
                    'Path to the HHblits executable.')
flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'),
                    'Path to the HHsearch executable.')
flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'),
                    'Path to the hmmsearch executable.')
flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'),
                    'Path to the hmmbuild executable.')
flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'),
                    'Path to the Kalign executable.')

# Sequence and template databases.
flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 '
                    'database for use by JackHMMER.')
flags.DEFINE_string('mgnify_database_path', None, 'Path to the MGnify '
                    'database for use by JackHMMER.')
flags.DEFINE_string('bfd_database_path', None, 'Path to the BFD '
                    'database for use by HHblits.')
flags.DEFINE_string('small_bfd_database_path', None, 'Path to the small '
                    'version of BFD used with the "reduced_dbs" preset.')
flags.DEFINE_string('uniref30_database_path', None, 'Path to the UniRef30 '
                    'database for use by HHblits.')
flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot '
                    'database for use by JackHMMer.')
flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 '
                    'database for use by HHsearch.')
flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB '
                    'seqres database for use by hmmsearch.')
flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with '
                    'template mmCIF structures, each named <pdb_id>.cif')
flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
                    'to consider. Important if folding historical test sets.')
flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a '
                    'mapping from obsolete PDB IDs to the PDB IDs of their '
                    'replacements.')

# Presets and run behaviour.
flags.DEFINE_enum('db_preset', 'full_dbs',
                  ['full_dbs', 'reduced_dbs'],
                  'Choose preset MSA database configuration - '
                  'smaller genetic database config (reduced_dbs) or '
                  'full genetic database config  (full_dbs)')
flags.DEFINE_enum('model_preset', 'monomer',
                  ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
                  'Choose preset model configuration - the monomer model, '
                  'the monomer model with extra ensembling, monomer model with '
                  'pTM head, or multimer model')
flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
                     'to obtain a timing that excludes the compilation time, '
                     'which should be more indicative of the time required for '
                     'inferencing many proteins.')
flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                     'pipeline. By default, this is randomly generated. Note '
                     'that even if this is set, Alphafold may still not be '
                     'deterministic, because processes like GPU inference are '
                     'nondeterministic.')
flags.DEFINE_integer('num_multimer_predictions_per_model', 5, 'How many '
                     'predictions (each with a different random seed) will be '
                     'generated per model. E.g. if this is 2 and there are 5 '
                     'models then there will be 10 predictions per input. '
                     'Note: this FLAG only applies if model_preset=multimer')
flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
                     'have been written to disk instead of running the MSA '
                     'tools. The MSA files are looked up in the output '
                     'directory, so it must stay the same between multiple '
                     'runs that are to reuse the MSAs. WARNING: This will not '
                     'check if the sequence, database or configuration have '
                     'changed.')
flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
                        'The models to run the final relaxation step on. '
                        'If `all`, all models are relaxed, which may be time '
                        'consuming. If `best`, only the most confident model '
                        'is relaxed. If `none`, relaxation is not run. Turning '
                        'off relaxation might result in predictions with '
                        'distracting stereochemical violations but might help '
                        'in case you are having issues with the relaxation '
                        'stage.')
flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
                     'Relax on GPU can be much faster than CPU, so it is '
                     'recommended to enable if possible. GPUs must be available'
                     ' if this setting is enabled.')

FLAGS = flags.FLAGS

# Passed as max_hits to the template hit featurizers built in main().
MAX_TEMPLATE_HITS = 20

# Settings forwarded verbatim to relax.AmberRelaxation in main(). Their exact
# meaning (e.g. whether max_iterations=0 means "no limit", and the units of the
# tolerance and stiffness) is defined by alphafold.relax.
# NOTE(review): confirm against alphafold.relax before changing these values.
RELAX_MAX_ITERATIONS = 0
RELAX_ENERGY_TOLERANCE = 2.39
RELAX_STIFFNESS = 10.0
RELAX_EXCLUDE_RESIDUES = []
RELAX_MAX_OUTER_ITERATIONS = 3


def _check_flag(flag_name: str,
                other_flag_name: str,
                should_be_set: bool):
  """Raises if a flag's presence does not match what another flag requires.

  Args:
    flag_name: Flag whose value is checked for truthiness.
    other_flag_name: Flag whose current value determines the requirement; its
      name and value are quoted in the error message.
    should_be_set: True if `flag_name` must hold a truthy value, False if it
      must be empty/unset.

  Raises:
    ValueError: If the truthiness of `flag_name` differs from `should_be_set`.
  """
  is_set = bool(FLAGS[flag_name].value)
  if is_set == should_be_set:
    return
  requirement = 'be' if should_be_set else 'not be'
  other_value = FLAGS[other_flag_name].value
  raise ValueError(f'{flag_name} must {requirement} set when running with '
                   f'"--{other_flag_name}={other_value}".')


def _jnp_to_np(output: Dict[str, Any]) -> Dict[str, Any]:
166
167
168
169
170
171
172
173
174
  """Recursively changes jax arrays to numpy arrays."""
  for k, v in output.items():
    if isinstance(v, dict):
      output[k] = _jnp_to_np(v)
    elif isinstance(v, jnp.ndarray):
      output[k] = np.array(v)
  return output


Augustin-Zidek's avatar
Augustin-Zidek committed
175
176
177
178
def predict_structure(
    fasta_path: str,
    fasta_name: str,
    output_dir_base: str,
    data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline],
    model_runners: Dict[str, model.RunModel],
    amber_relaxer: relax.AmberRelaxation,
    benchmark: bool,
    random_seed: int,
    models_to_relax: ModelsToRelax):
  """Predicts structure using AlphaFold for the given sequence.

  Runs the data pipeline once on the FASTA input, runs every model runner on
  the resulting features, ranks the predictions by their ranking confidence,
  optionally relaxes some of them, and writes all artefacts (features, raw
  results, unrelaxed/relaxed/ranked PDBs, ranking, timings and relax metrics)
  to `<output_dir_base>/<fasta_name>/`.

  Args:
    fasta_path: Path to the input FASTA file.
    fasta_name: Name of the per-target output directory.
    output_dir_base: Directory in which the per-target directory is created.
    data_pipeline: Pipeline that turns the FASTA file into model features.
    model_runners: Mapping from prediction name to the model runner used for
      that prediction.
    amber_relaxer: Relaxer applied to the predictions chosen by
      `models_to_relax`.
    benchmark: If True, each model is run a second time to record a timing
      that excludes JAX compilation.
    random_seed: Base seed; the i-th runner (in `model_runners` order) is
      seeded with `i + random_seed * len(model_runners)`.
    models_to_relax: Which ranked predictions to relax.

  Raises:
    ValueError: If `models_to_relax` is not a `ModelsToRelax` member.
  """
  # Validate before the (expensive) feature pipeline runs, so a bad value
  # cannot leave `to_relax` unbound after all models have been evaluated.
  if not isinstance(models_to_relax, ModelsToRelax):
    raise ValueError(f'Unsupported models_to_relax value: {models_to_relax!r}')

  logging.info('Predicting %s', fasta_name)
  timings = {}
  output_dir = os.path.join(output_dir_base, fasta_name)
  # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
  os.makedirs(output_dir, exist_ok=True)
  msa_output_dir = os.path.join(output_dir, 'msas')
  os.makedirs(msa_output_dir, exist_ok=True)

  # Get features.
  t_0 = time.time()
  feature_dict = data_pipeline.process(
      input_fasta_path=fasta_path,
      msa_output_dir=msa_output_dir)
  timings['features'] = time.time() - t_0

  # Write out features as a pickled dictionary.
  features_output_path = os.path.join(output_dir, 'features.pkl')
  with open(features_output_path, 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)

  unrelaxed_pdbs = {}
  unrelaxed_proteins = {}
  relaxed_pdbs = {}
  relax_metrics = {}
  ranking_confidences = {}
  # Key for the confidences in ranking_debug.json. It is taken from the last
  # prediction: 'iptm+ptm' if that result has an 'iptm' entry, else 'plddts'.
  # Initialised here so the JSON can still be written when there are no models.
  ranking_label = 'plddts'

  # Run the models.
  num_models = len(model_runners)
  for model_index, (model_name, model_runner) in enumerate(
      model_runners.items()):
    logging.info('Running model %s on %s', model_name, fasta_name)
    t_0 = time.time()
    # Distinct, reproducible seed per prediction.
    model_random_seed = model_index + random_seed * num_models
    processed_feature_dict = model_runner.process_features(
        feature_dict, random_seed=model_random_seed)
    timings[f'process_features_{model_name}'] = time.time() - t_0

    t_0 = time.time()
    prediction_result = model_runner.predict(processed_feature_dict,
                                             random_seed=model_random_seed)
    t_diff = time.time() - t_0
    timings[f'predict_and_compile_{model_name}'] = t_diff
    logging.info(
        'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs',
        model_name, fasta_name, t_diff)

    if benchmark:
      t_0 = time.time()
      model_runner.predict(processed_feature_dict,
                           random_seed=model_random_seed)
      t_diff = time.time() - t_0
      timings[f'predict_benchmark_{model_name}'] = t_diff
      logging.info(
          'Total JAX model %s on %s predict time (excludes compilation time): %.1fs',
          model_name, fasta_name, t_diff)

    plddt = prediction_result['plddt']
    ranking_confidences[model_name] = prediction_result['ranking_confidence']
    ranking_label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts'

    # Remove jax dependency from results.
    np_prediction_result = _jnp_to_np(dict(prediction_result))

    # Save the model outputs.
    result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
    with open(result_output_path, 'wb') as f:
      pickle.dump(np_prediction_result, f, protocol=4)

    # Add the predicted LDDT in the b-factor column.
    # Note that higher predicted LDDT value means higher model confidence.
    plddt_b_factors = np.repeat(
        plddt[:, None], residue_constants.atom_type_num, axis=-1)
    unrelaxed_protein = protein.from_prediction(
        features=processed_feature_dict,
        result=prediction_result,
        b_factors=plddt_b_factors,
        remove_leading_feature_dimension=not model_runner.multimer_mode)

    unrelaxed_proteins[model_name] = unrelaxed_protein
    unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
    unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
    with open(unrelaxed_pdb_path, 'w') as f:
      f.write(unrelaxed_pdbs[model_name])

  # Rank by model confidence (stable sort, highest confidence first).
  ranked_order = sorted(
      ranking_confidences, key=ranking_confidences.get, reverse=True)

  # Relax predictions.
  if models_to_relax == ModelsToRelax.ALL:
    to_relax = ranked_order
  elif models_to_relax == ModelsToRelax.BEST:
    # Slicing instead of [0] keeps this safe when there are no predictions.
    to_relax = ranked_order[:1]
  else:  # ModelsToRelax.NONE (membership validated at the top).
    to_relax = []

  for model_name in to_relax:
    t_0 = time.time()
    relaxed_pdb_str, _, violations = amber_relaxer.process(
        prot=unrelaxed_proteins[model_name])
    relax_metrics[model_name] = {
        'remaining_violations': violations,
        'remaining_violations_count': sum(violations)
    }
    timings[f'relax_{model_name}'] = time.time() - t_0

    relaxed_pdbs[model_name] = relaxed_pdb_str

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        output_dir, f'relaxed_{model_name}.pdb')
    with open(relaxed_output_path, 'w') as f:
      f.write(relaxed_pdb_str)

  # Write out PDBs in rank order, preferring the relaxed structure when one
  # exists and falling back to the unrelaxed one otherwise.
  for idx, model_name in enumerate(ranked_order):
    ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
    with open(ranked_output_path, 'w') as f:
      if model_name in relaxed_pdbs:
        f.write(relaxed_pdbs[model_name])
      else:
        f.write(unrelaxed_pdbs[model_name])

  ranking_output_path = os.path.join(output_dir, 'ranking_debug.json')
  with open(ranking_output_path, 'w') as f:
    f.write(json.dumps(
        {ranking_label: ranking_confidences, 'order': ranked_order}, indent=4))

  logging.info('Final timings for %s: %s', fasta_name, timings)

  timings_output_path = os.path.join(output_dir, 'timings.json')
  with open(timings_output_path, 'w') as f:
    f.write(json.dumps(timings, indent=4))
  if models_to_relax != ModelsToRelax.NONE:
    relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
    with open(relax_metrics_path, 'w') as f:
      f.write(json.dumps(relax_metrics, indent=4))


def main(argv):
  """Validates flags, builds the data pipelines and models, and folds targets.

  Args:
    argv: Positional command-line arguments left over after flag parsing.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If a required tool binary is missing, database flags do not
      match the selected presets, FASTA basenames are not unique, or no model
      runners would be created.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  for tool_name in (
      'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'):
    if not FLAGS[f'{tool_name}_binary_path'].value:
      raise ValueError(f'Could not find path to the "{tool_name}" binary. Make '
                       'sure it is installed on your system.')

  use_small_bfd = FLAGS.db_preset == 'reduced_dbs'
  _check_flag('small_bfd_database_path', 'db_preset',
              should_be_set=use_small_bfd)
  _check_flag('bfd_database_path', 'db_preset',
              should_be_set=not use_small_bfd)
  _check_flag('uniref30_database_path', 'db_preset',
              should_be_set=not use_small_bfd)

  run_multimer_system = 'multimer' in FLAGS.model_preset
  _check_flag('pdb70_database_path', 'model_preset',
              should_be_set=not run_multimer_system)
  _check_flag('pdb_seqres_database_path', 'model_preset',
              should_be_set=run_multimer_system)
  _check_flag('uniprot_database_path', 'model_preset',
              should_be_set=run_multimer_system)

  # Only the CASP14 preset uses extra ensembling.
  num_ensemble = 8 if FLAGS.model_preset == 'monomer_casp14' else 1

  # Check for duplicate FASTA file names.
  fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths]
  if len(fasta_names) != len(set(fasta_names)):
    raise ValueError('All FASTA paths must have a unique basename.')

  # Multimer templates come from hmmsearch over PDB seqres; monomer templates
  # come from HHsearch over PDB70.
  if run_multimer_system:
    template_searcher = hmmsearch.Hmmsearch(
        binary_path=FLAGS.hmmsearch_binary_path,
        hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
        database_path=FLAGS.pdb_seqres_database_path)
    template_featurizer = templates.HmmsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)
  else:
    template_searcher = hhsearch.HHSearch(
        binary_path=FLAGS.hhsearch_binary_path,
        databases=[FLAGS.pdb70_database_path])
    template_featurizer = templates.HhsearchHitFeaturizer(
        mmcif_dir=FLAGS.template_mmcif_dir,
        max_template_date=FLAGS.max_template_date,
        max_hits=MAX_TEMPLATE_HITS,
        kalign_binary_path=FLAGS.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=FLAGS.obsolete_pdbs_path)

  monomer_data_pipeline = pipeline.DataPipeline(
      jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
      hhblits_binary_path=FLAGS.hhblits_binary_path,
      uniref90_database_path=FLAGS.uniref90_database_path,
      mgnify_database_path=FLAGS.mgnify_database_path,
      bfd_database_path=FLAGS.bfd_database_path,
      uniref30_database_path=FLAGS.uniref30_database_path,
      small_bfd_database_path=FLAGS.small_bfd_database_path,
      template_searcher=template_searcher,
      template_featurizer=template_featurizer,
      use_small_bfd=use_small_bfd,
      use_precomputed_msas=FLAGS.use_precomputed_msas)

  if run_multimer_system:
    num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
    data_pipeline = pipeline_multimer.DataPipeline(
        monomer_data_pipeline=monomer_data_pipeline,
        jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
        uniprot_database_path=FLAGS.uniprot_database_path,
        use_precomputed_msas=FLAGS.use_precomputed_msas)
  else:
    num_predictions_per_model = 1
    data_pipeline = monomer_data_pipeline

  # Models are chosen explicitly via --model_names rather than by the preset.
  model_runners = {}
  for model_name in FLAGS.model_names or []:
    model_config = config.model_config(model_name)
    if run_multimer_system:
      model_config.model.num_ensemble_eval = num_ensemble
    else:
      model_config.data.eval.num_ensemble = num_ensemble
    model_params = data.get_model_haiku_params(
        model_name=model_name, data_dir=FLAGS.data_dir)
    model_runner = model.RunModel(model_config, model_params)
    for i in range(num_predictions_per_model):
      model_runners[f'{model_name}_pred_{i}'] = model_runner

  # Without this check an empty mapping would fail later with an opaque
  # ZeroDivisionError in the random seed computation below.
  if not model_runners:
    raise ValueError(
        'No models to run: --model_names must list at least one model and '
        '--num_multimer_predictions_per_model must be positive.')

  logging.info('Have %d models: %s', len(model_runners),
               list(model_runners.keys()))

  amber_relaxer = relax.AmberRelaxation(
      max_iterations=RELAX_MAX_ITERATIONS,
      tolerance=RELAX_ENERGY_TOLERANCE,
      stiffness=RELAX_STIFFNESS,
      exclude_residues=RELAX_EXCLUDE_RESIDUES,
      max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
      use_gpu=FLAGS.use_gpu_relax)

  random_seed = FLAGS.random_seed
  if random_seed is None:
    # Bound the seed so that per-model seeds (index + seed * num_models,
    # computed in predict_structure) stay below sys.maxsize.
    random_seed = random.randrange(sys.maxsize // len(model_runners))
  logging.info('Using random seed %d for the data pipeline', random_seed)

  # Predict structure for each of the sequences.
  for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names):
    predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=FLAGS.output_dir,
        data_pipeline=data_pipeline,
        model_runners=model_runners,
        amber_relaxer=amber_relaxer,
        benchmark=FLAGS.benchmark,
        random_seed=random_seed,
        models_to_relax=FLAGS.models_to_relax)


if __name__ == '__main__':
  # Flags without a usable default that every run must supply.
  flags.mark_flags_as_required([
      'fasta_paths',
      'output_dir',
      'data_dir',
      'model_names',
      'uniref90_database_path',
      'mgnify_database_path',
      'template_mmcif_dir',
      'max_template_date',
      'obsolete_pdbs_path',
      'use_gpu_relax',
  ])

  app.run(main)