Commit 684ffa19 authored by Hamish Tomlinson, committed by Copybara-Service
Browse files

Add ability to only run relax for the best unrelaxed model.

PiperOrigin-RevId: 501851892
Change-Id: I7484c2aa7ac30af611d88cfa8d632096c262824a
parent 0d9a24b2
...@@ -28,12 +28,15 @@ from docker import types ...@@ -28,12 +28,15 @@ from docker import types
flags.DEFINE_bool( flags.DEFINE_bool(
'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.') 'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
flags.DEFINE_boolean( flags.DEFINE_enum('models_to_relax', 'best', ['best', 'all', 'none'],
'run_relax', True, 'The models to run the final relaxation step on. '
'Whether to run the final relaxation step on the predicted models. Turning ' 'If `all`, all models are relaxed, which may be time '
'relax off might result in predictions with distracting stereochemical ' 'consuming. If `best`, only the most confident model is '
'violations but might help in case you are having issues with the ' 'relaxed. If `none`, relaxation is not run. Turning off '
'relaxation stage.') 'relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_bool( flags.DEFINE_bool(
'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.') 'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
flags.DEFINE_string( flags.DEFINE_string(
...@@ -221,7 +224,7 @@ def main(argv): ...@@ -221,7 +224,7 @@ def main(argv):
f'--benchmark={FLAGS.benchmark}', f'--benchmark={FLAGS.benchmark}',
f'--use_precomputed_msas={FLAGS.use_precomputed_msas}', f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}', f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}',
f'--run_relax={FLAGS.run_relax}', f'--models_to_relax={FLAGS.models_to_relax}',
f'--use_gpu_relax={use_gpu_relax}', f'--use_gpu_relax={use_gpu_relax}',
'--logtostderr', '--logtostderr',
]) ])
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Full AlphaFold protein structure prediction script.""" """Full AlphaFold protein structure prediction script."""
import enum
import json import json
import os import os
import pathlib import pathlib
...@@ -43,6 +44,13 @@ import numpy as np ...@@ -43,6 +44,13 @@ import numpy as np
logging.set_verbosity(logging.INFO) logging.set_verbosity(logging.INFO)
@enum.unique
class ModelsToRelax(enum.Enum):
ALL = 0
BEST = 1
NONE = 2
flags.DEFINE_list( flags.DEFINE_list(
'fasta_paths', None, 'Paths to FASTA files, each containing a prediction ' 'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
'target that will be folded one after another. If a FASTA file contains ' 'target that will be folded one after another. If a FASTA file contains '
...@@ -119,11 +127,15 @@ flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that ' ...@@ -119,11 +127,15 @@ flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
'runs that are to reuse the MSAs. WARNING: This will not ' 'runs that are to reuse the MSAs. WARNING: This will not '
'check if the sequence, database or configuration have ' 'check if the sequence, database or configuration have '
'changed.') 'changed.')
flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation ' flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
'step on the predicted models. Turning relax off might ' 'The models to run the final relaxation step on. '
'result in predictions with distracting stereochemical ' 'If `all`, all models are relaxed, which may be time '
'violations but might help in case you are having issues ' 'consuming. If `best`, only the most confident model '
'with the relaxation stage.') 'is relaxed. If `none`, relaxation is not run. Turning '
'off relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. ' flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
'Relax on GPU can be much faster than CPU, so it is ' 'Relax on GPU can be much faster than CPU, so it is '
'recommended to enable if possible. GPUs must be available' 'recommended to enable if possible. GPUs must be available'
...@@ -156,7 +168,8 @@ def predict_structure( ...@@ -156,7 +168,8 @@ def predict_structure(
model_runners: Dict[str, model.RunModel], model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation, amber_relaxer: relax.AmberRelaxation,
benchmark: bool, benchmark: bool,
random_seed: int): random_seed: int,
models_to_relax: ModelsToRelax):
"""Predicts structure using AlphaFold for the given sequence.""" """Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name) logging.info('Predicting %s', fasta_name)
timings = {} timings = {}
...@@ -180,6 +193,7 @@ def predict_structure( ...@@ -180,6 +193,7 @@ def predict_structure(
pickle.dump(feature_dict, f, protocol=4) pickle.dump(feature_dict, f, protocol=4)
unrelaxed_pdbs = {} unrelaxed_pdbs = {}
unrelaxed_proteins = {}
relaxed_pdbs = {} relaxed_pdbs = {}
relax_metrics = {} relax_metrics = {}
ranking_confidences = {} ranking_confidences = {}
...@@ -232,38 +246,48 @@ def predict_structure( ...@@ -232,38 +246,48 @@ def predict_structure(
b_factors=plddt_b_factors, b_factors=plddt_b_factors,
remove_leading_feature_dimension=not model_runner.multimer_mode) remove_leading_feature_dimension=not model_runner.multimer_mode)
unrelaxed_proteins[model_name] = unrelaxed_protein
unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein) unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
with open(unrelaxed_pdb_path, 'w') as f: with open(unrelaxed_pdb_path, 'w') as f:
f.write(unrelaxed_pdbs[model_name]) f.write(unrelaxed_pdbs[model_name])
if amber_relaxer: # Rank by model confidence.
# Relax the prediction. ranked_order = [
t_0 = time.time() model_name for model_name, confidence in
relaxed_pdb_str, _, violations = amber_relaxer.process( sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]
prot=unrelaxed_protein)
relax_metrics[model_name] = { # Relax predictions.
'remaining_violations': violations, if models_to_relax == ModelsToRelax.BEST:
'remaining_violations_count': sum(violations) to_relax = [ranked_order[0]]
} elif models_to_relax == ModelsToRelax.ALL:
timings[f'relax_{model_name}'] = time.time() - t_0 to_relax = ranked_order
elif models_to_relax == ModelsToRelax.NONE:
relaxed_pdbs[model_name] = relaxed_pdb_str to_relax = []
# Save the relaxed PDB. for model_name in to_relax:
relaxed_output_path = os.path.join( t_0 = time.time()
output_dir, f'relaxed_{model_name}.pdb') relaxed_pdb_str, _, violations = amber_relaxer.process(
with open(relaxed_output_path, 'w') as f: prot=unrelaxed_proteins[model_name])
f.write(relaxed_pdb_str) relax_metrics[model_name] = {
'remaining_violations': violations,
# Rank by model confidence and write out relaxed PDBs in rank order. 'remaining_violations_count': sum(violations)
ranked_order = [] }
for idx, (model_name, _) in enumerate( timings[f'relax_{model_name}'] = time.time() - t_0
sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
ranked_order.append(model_name) relaxed_pdbs[model_name] = relaxed_pdb_str
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_dir, f'relaxed_{model_name}.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# Write out relaxed PDBs in rank order.
for idx, model_name in enumerate(ranked_order):
ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
with open(ranked_output_path, 'w') as f: with open(ranked_output_path, 'w') as f:
if amber_relaxer: if model_name in relaxed_pdbs:
f.write(relaxed_pdbs[model_name]) f.write(relaxed_pdbs[model_name])
else: else:
f.write(unrelaxed_pdbs[model_name]) f.write(unrelaxed_pdbs[model_name])
...@@ -279,7 +303,7 @@ def predict_structure( ...@@ -279,7 +303,7 @@ def predict_structure(
timings_output_path = os.path.join(output_dir, 'timings.json') timings_output_path = os.path.join(output_dir, 'timings.json')
with open(timings_output_path, 'w') as f: with open(timings_output_path, 'w') as f:
f.write(json.dumps(timings, indent=4)) f.write(json.dumps(timings, indent=4))
if amber_relaxer: if models_to_relax != ModelsToRelax.NONE:
relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json') relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
with open(relax_metrics_path, 'w') as f: with open(relax_metrics_path, 'w') as f:
f.write(json.dumps(relax_metrics, indent=4)) f.write(json.dumps(relax_metrics, indent=4))
...@@ -386,16 +410,13 @@ def main(argv): ...@@ -386,16 +410,13 @@ def main(argv):
logging.info('Have %d models: %s', len(model_runners), logging.info('Have %d models: %s', len(model_runners),
list(model_runners.keys())) list(model_runners.keys()))
if FLAGS.run_relax: amber_relaxer = relax.AmberRelaxation(
amber_relaxer = relax.AmberRelaxation( max_iterations=RELAX_MAX_ITERATIONS,
max_iterations=RELAX_MAX_ITERATIONS, tolerance=RELAX_ENERGY_TOLERANCE,
tolerance=RELAX_ENERGY_TOLERANCE, stiffness=RELAX_STIFFNESS,
stiffness=RELAX_STIFFNESS, exclude_residues=RELAX_EXCLUDE_RESIDUES,
exclude_residues=RELAX_EXCLUDE_RESIDUES, max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS, use_gpu=FLAGS.use_gpu_relax)
use_gpu=FLAGS.use_gpu_relax)
else:
amber_relaxer = None
random_seed = FLAGS.random_seed random_seed = FLAGS.random_seed
if random_seed is None: if random_seed is None:
...@@ -413,7 +434,8 @@ def main(argv): ...@@ -413,7 +434,8 @@ def main(argv):
model_runners=model_runners, model_runners=model_runners,
amber_relaxer=amber_relaxer, amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark, benchmark=FLAGS.benchmark,
random_seed=random_seed) random_seed=random_seed,
models_to_relax=FLAGS.models_to_relax)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -28,10 +28,10 @@ import numpy as np ...@@ -28,10 +28,10 @@ import numpy as np
class RunAlphafoldTest(parameterized.TestCase): class RunAlphafoldTest(parameterized.TestCase):
@parameterized.named_parameters( @parameterized.named_parameters(
('relax', True), ('relax', run_alphafold.ModelsToRelax.ALL),
('no_relax', False), ('no_relax', run_alphafold.ModelsToRelax.NONE),
) )
def test_end_to_end(self, do_relax): def test_end_to_end(self, models_to_relax):
data_pipeline_mock = mock.Mock() data_pipeline_mock = mock.Mock()
model_runner_mock = mock.Mock() model_runner_mock = mock.Mock()
...@@ -72,9 +72,11 @@ class RunAlphafoldTest(parameterized.TestCase): ...@@ -72,9 +72,11 @@ class RunAlphafoldTest(parameterized.TestCase):
output_dir_base=out_dir, output_dir_base=out_dir,
data_pipeline=data_pipeline_mock, data_pipeline=data_pipeline_mock,
model_runners={'model1': model_runner_mock}, model_runners={'model1': model_runner_mock},
amber_relaxer=amber_relaxer_mock if do_relax else None, amber_relaxer=amber_relaxer_mock,
benchmark=False, benchmark=False,
random_seed=0) random_seed=0,
models_to_relax=models_to_relax,
)
base_output_files = os.listdir(out_dir) base_output_files = os.listdir(out_dir)
self.assertIn('target.fasta', base_output_files) self.assertIn('target.fasta', base_output_files)
...@@ -85,7 +87,7 @@ class RunAlphafoldTest(parameterized.TestCase): ...@@ -85,7 +87,7 @@ class RunAlphafoldTest(parameterized.TestCase):
'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', 'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb', 'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
] ]
if do_relax: if models_to_relax == run_alphafold.ModelsToRelax.ALL:
expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json']) expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json'])
with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f: with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f:
relax_metrics = json.loads(f.read()) relax_metrics = json.loads(f.read())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment