Commit 684ffa19 authored by Hamish Tomlinson's avatar Hamish Tomlinson Committed by Copybara-Service
Browse files

Add ability to only run relax for the best unrelaxed model.

PiperOrigin-RevId: 501851892
Change-Id: I7484c2aa7ac30af611d88cfa8d632096c262824a
parent 0d9a24b2
......@@ -28,12 +28,15 @@ from docker import types
flags.DEFINE_bool(
'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
flags.DEFINE_boolean(
'run_relax', True,
'Whether to run the final relaxation step on the predicted models. Turning '
'relax off might result in predictions with distracting stereochemical '
'violations but might help in case you are having issues with the '
'relaxation stage.')
flags.DEFINE_enum('models_to_relax', 'best', ['best', 'all', 'none'],
'The models to run the final relaxation step on. '
'If `all`, all models are relaxed, which may be time '
'consuming. If `best`, only the most confident model is '
'relaxed. If `none`, relaxation is not run. Turning off '
'relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_bool(
'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
flags.DEFINE_string(
......@@ -221,7 +224,7 @@ def main(argv):
f'--benchmark={FLAGS.benchmark}',
f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}',
f'--run_relax={FLAGS.run_relax}',
f'--models_to_relax={FLAGS.models_to_relax}',
f'--use_gpu_relax={use_gpu_relax}',
'--logtostderr',
])
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Full AlphaFold protein structure prediction script."""
import enum
import json
import os
import pathlib
......@@ -43,6 +44,13 @@ import numpy as np
logging.set_verbosity(logging.INFO)
@enum.unique
class ModelsToRelax(enum.Enum):
ALL = 0
BEST = 1
NONE = 2
flags.DEFINE_list(
'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
'target that will be folded one after another. If a FASTA file contains '
......@@ -119,11 +127,15 @@ flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that '
'runs that are to reuse the MSAs. WARNING: This will not '
'check if the sequence, database or configuration have '
'changed.')
flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
'step on the predicted models. Turning relax off might '
'result in predictions with distracting stereochemical '
'violations but might help in case you are having issues '
'with the relaxation stage.')
flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
'The models to run the final relaxation step on. '
'If `all`, all models are relaxed, which may be time '
'consuming. If `best`, only the most confident model '
'is relaxed. If `none`, relaxation is not run. Turning '
'off relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
'Relax on GPU can be much faster than CPU, so it is '
'recommended to enable if possible. GPUs must be available'
......@@ -156,7 +168,8 @@ def predict_structure(
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
benchmark: bool,
random_seed: int):
random_seed: int,
models_to_relax: ModelsToRelax):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {}
......@@ -180,6 +193,7 @@ def predict_structure(
pickle.dump(feature_dict, f, protocol=4)
unrelaxed_pdbs = {}
unrelaxed_proteins = {}
relaxed_pdbs = {}
relax_metrics = {}
ranking_confidences = {}
......@@ -232,38 +246,48 @@ def predict_structure(
b_factors=plddt_b_factors,
remove_leading_feature_dimension=not model_runner.multimer_mode)
unrelaxed_proteins[model_name] = unrelaxed_protein
unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
with open(unrelaxed_pdb_path, 'w') as f:
f.write(unrelaxed_pdbs[model_name])
if amber_relaxer:
# Relax the prediction.
t_0 = time.time()
relaxed_pdb_str, _, violations = amber_relaxer.process(
prot=unrelaxed_protein)
relax_metrics[model_name] = {
'remaining_violations': violations,
'remaining_violations_count': sum(violations)
}
timings[f'relax_{model_name}'] = time.time() - t_0
relaxed_pdbs[model_name] = relaxed_pdb_str
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_dir, f'relaxed_{model_name}.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# Rank by model confidence and write out relaxed PDBs in rank order.
ranked_order = []
for idx, (model_name, _) in enumerate(
sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
ranked_order.append(model_name)
# Rank by model confidence.
ranked_order = [
model_name for model_name, confidence in
sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]
# Relax predictions.
if models_to_relax == ModelsToRelax.BEST:
to_relax = [ranked_order[0]]
elif models_to_relax == ModelsToRelax.ALL:
to_relax = ranked_order
elif models_to_relax == ModelsToRelax.NONE:
to_relax = []
for model_name in to_relax:
t_0 = time.time()
relaxed_pdb_str, _, violations = amber_relaxer.process(
prot=unrelaxed_proteins[model_name])
relax_metrics[model_name] = {
'remaining_violations': violations,
'remaining_violations_count': sum(violations)
}
timings[f'relax_{model_name}'] = time.time() - t_0
relaxed_pdbs[model_name] = relaxed_pdb_str
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_dir, f'relaxed_{model_name}.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
# Write out relaxed PDBs in rank order.
for idx, model_name in enumerate(ranked_order):
ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
with open(ranked_output_path, 'w') as f:
if amber_relaxer:
if model_name in relaxed_pdbs:
f.write(relaxed_pdbs[model_name])
else:
f.write(unrelaxed_pdbs[model_name])
......@@ -279,7 +303,7 @@ def predict_structure(
timings_output_path = os.path.join(output_dir, 'timings.json')
with open(timings_output_path, 'w') as f:
f.write(json.dumps(timings, indent=4))
if amber_relaxer:
if models_to_relax != ModelsToRelax.NONE:
relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
with open(relax_metrics_path, 'w') as f:
f.write(json.dumps(relax_metrics, indent=4))
......@@ -386,16 +410,13 @@ def main(argv):
logging.info('Have %d models: %s', len(model_runners),
list(model_runners.keys()))
if FLAGS.run_relax:
amber_relaxer = relax.AmberRelaxation(
max_iterations=RELAX_MAX_ITERATIONS,
tolerance=RELAX_ENERGY_TOLERANCE,
stiffness=RELAX_STIFFNESS,
exclude_residues=RELAX_EXCLUDE_RESIDUES,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=FLAGS.use_gpu_relax)
else:
amber_relaxer = None
amber_relaxer = relax.AmberRelaxation(
max_iterations=RELAX_MAX_ITERATIONS,
tolerance=RELAX_ENERGY_TOLERANCE,
stiffness=RELAX_STIFFNESS,
exclude_residues=RELAX_EXCLUDE_RESIDUES,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=FLAGS.use_gpu_relax)
random_seed = FLAGS.random_seed
if random_seed is None:
......@@ -413,7 +434,8 @@ def main(argv):
model_runners=model_runners,
amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark,
random_seed=random_seed)
random_seed=random_seed,
models_to_relax=FLAGS.models_to_relax)
if __name__ == '__main__':
......
......@@ -28,10 +28,10 @@ import numpy as np
class RunAlphafoldTest(parameterized.TestCase):
@parameterized.named_parameters(
('relax', True),
('no_relax', False),
('relax', run_alphafold.ModelsToRelax.ALL),
('no_relax', run_alphafold.ModelsToRelax.NONE),
)
def test_end_to_end(self, do_relax):
def test_end_to_end(self, models_to_relax):
data_pipeline_mock = mock.Mock()
model_runner_mock = mock.Mock()
......@@ -72,9 +72,11 @@ class RunAlphafoldTest(parameterized.TestCase):
output_dir_base=out_dir,
data_pipeline=data_pipeline_mock,
model_runners={'model1': model_runner_mock},
amber_relaxer=amber_relaxer_mock if do_relax else None,
amber_relaxer=amber_relaxer_mock,
benchmark=False,
random_seed=0)
random_seed=0,
models_to_relax=models_to_relax,
)
base_output_files = os.listdir(out_dir)
self.assertIn('target.fasta', base_output_files)
......@@ -85,7 +87,7 @@ class RunAlphafoldTest(parameterized.TestCase):
'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
]
if do_relax:
if models_to_relax == run_alphafold.ModelsToRelax.ALL:
expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json'])
with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f:
relax_metrics = json.loads(f.read())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment