Commit eb49d157 authored by Rich Evans, committed by Copybara-Service

Enable multiple seeds per model for AlphaFold-Multimer.

PiperOrigin-RevId: 429318717
Change-Id: Ib07df7ecd4bc52feea80a66c585ce51f332ac7a1
parent 2cd61ade
@@ -318,6 +318,11 @@ python3 docker/run_docker.py \
   --data_dir=$DOWNLOAD_DIR
 ```
+By default the multimer system will run 5 seeds per model (25 total predictions).
+For a small drop in accuracy, you may wish to run a single seed per model. This
+can be done via the `--num_multimer_predictions_per_model` flag, e.g. set it to
+`--num_multimer_predictions_per_model=1` to run a single seed per model.
 ### Examples
 Below are examples of how to use AlphaFold in different scenarios.
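As an illustration of the flag described above, a single seed per model could be requested by appending it to the Docker invocation shown earlier. This is only a sketch: the FASTA path and template date are placeholders, and the remaining flags mirror the command at the top of this hunk.

```
python3 docker/run_docker.py \
  --fasta_paths=multimer.fasta \
  --max_template_date=2022-01-01 \
  --model_preset=multimer \
  --num_multimer_predictions_per_model=1 \
  --data_dir=$DOWNLOAD_DIR
```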
@@ -73,6 +73,11 @@ flags.DEFINE_enum(
     ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'],
     'Choose preset model configuration - the monomer model, the monomer model '
     'with extra ensembling, monomer model with pTM head, or multimer model')
+flags.DEFINE_integer('num_multimer_predictions_per_model', 5, 'How many '
+                     'predictions (each with a different random seed) will be '
+                     'generated per model. E.g. if this is 2 and there are 5 '
+                     'models then there will be 10 predictions per input. '
+                     'Note: this FLAG only applies if model_preset=multimer')
 flags.DEFINE_boolean(
     'benchmark', False,
     'Run multiple JAX model evaluations to obtain a timing that excludes the '
@@ -213,6 +218,7 @@ def main(argv):
       f'--model_preset={FLAGS.model_preset}',
       f'--benchmark={FLAGS.benchmark}',
       f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
+      f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}',
       f'--run_relax={FLAGS.run_relax}',
       f'--use_gpu_relax={use_gpu_relax}',
       '--logtostderr',
@@ -113,6 +113,11 @@ flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
                      'that even if this is set, Alphafold may still not be '
                      'deterministic, because processes like GPU inference are '
                      'nondeterministic.')
+flags.DEFINE_integer('num_multimer_predictions_per_model', 5, 'How many '
+                     'predictions (each with a different random seed) will be '
+                     'generated per model. E.g. if this is 2 and there are 5 '
+                     'models then there will be 10 predictions per input. '
+                     'Note: this FLAG only applies if model_preset=multimer')
 flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
                      'have been written to disk instead of running the MSA '
                      'tools. The MSA files are looked up in the output '
@@ -373,12 +378,14 @@ def main(argv):
       use_precomputed_msas=FLAGS.use_precomputed_msas)
   if run_multimer_system:
+    num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
     data_pipeline = pipeline_multimer.DataPipeline(
         monomer_data_pipeline=monomer_data_pipeline,
         jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
         uniprot_database_path=FLAGS.uniprot_database_path,
         use_precomputed_msas=FLAGS.use_precomputed_msas)
   else:
+    num_predictions_per_model = 1
     data_pipeline = monomer_data_pipeline
   model_runners = {}
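To make the bookkeeping concrete, a small illustrative sketch (not code from this commit) of how the flag and the number of models multiply into the totals quoted in the README; `num_models = 5` is inferred from "5 seeds per model (25 total predictions)".

```python
# Illustrative sketch only: total predictions per input under the defaults.
num_models = 5  # multimer model parameter sets implied by 25 / 5 seeds
num_multimer_predictions_per_model = 5  # flag default

total_predictions = num_models * num_multimer_predictions_per_model
assert total_predictions == 25  # matches "25 total predictions" in the README

# With --num_multimer_predictions_per_model=1 the same input yields 5 predictions.
assert num_models * 1 == 5
```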
@@ -392,7 +399,8 @@ def main(argv):
     model_params = data.get_model_haiku_params(
         model_name=model_name, data_dir=FLAGS.data_dir)
     model_runner = model.RunModel(model_config, model_params)
-    model_runners[model_name] = model_runner
+    for i in range(num_predictions_per_model):
+      model_runners[f'{model_name}_pred_{i}'] = model_runner
   logging.info('Have %d models: %s', len(model_runners),
                list(model_runners.keys()))
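The loop above registers the same `RunModel` instance under several `'{model_name}_pred_{i}'` keys, so parameters are shared while each entry can later be driven with its own random seed. A minimal sketch of that idea follows; the `derive_seeds` helper and the `base_seed` arithmetic are illustrative assumptions, not code from this commit.

```python
# Illustrative sketch: give each expanded runner key its own random seed while
# all keys for one model share the same underlying parameters.
import random
import sys


def derive_seeds(model_runners, base_seed=None):
  """Returns one distinct seed per entry in model_runners (hypothetical helper)."""
  if base_seed is None:
    # Leave headroom so base_seed + index never overflows.
    base_seed = random.randrange(sys.maxsize - len(model_runners))
  return {name: base_seed + i for i, name in enumerate(model_runners)}


# Example: 5 multimer models x 2 predictions per model -> 10 runner keys,
# 10 distinct seeds, but only 5 underlying parameter sets.
shared_runner = object()  # stands in for a model.RunModel instance
model_runners = {
    f'model_{m}_multimer_pred_{i}': shared_runner
    for m in range(1, 6)
    for i in range(2)
}
seeds = derive_seeds(model_runners, base_seed=0)
assert len(seeds) == 10 and len(set(seeds.values())) == 10
```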