Commit 91ac85ac authored by Hamish Tomlinson's avatar Hamish Tomlinson Committed by Copybara-Service
Browse files

Remove jax dependency from results pkl

PiperOrigin-RevId: 502584418
Change-Id: I8ab14363a6342726691213bbeb180c0a8ff7f932
parent 684ffa19
...@@ -22,7 +22,7 @@ import random ...@@ -22,7 +22,7 @@ import random
import shutil import shutil
import sys import sys
import time import time
from typing import Dict, Union from typing import Any, Dict, Mapping, Union
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -38,6 +38,7 @@ from alphafold.model import config ...@@ -38,6 +38,7 @@ from alphafold.model import config
from alphafold.model import data from alphafold.model import data
from alphafold.model import model from alphafold.model import model
from alphafold.relax import relax from alphafold.relax import relax
import jax.numpy as jnp
import numpy as np import numpy as np
# Internal import (7716). # Internal import (7716).
...@@ -160,6 +161,16 @@ def _check_flag(flag_name: str, ...@@ -160,6 +161,16 @@ def _check_flag(flag_name: str,
f'"--{other_flag_name}={FLAGS[other_flag_name].value}".') f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
def _jnp_to_np(output: dict[str, Any]) -> dict[str, Any]:
"""Recursively changes jax arrays to numpy arrays."""
for k, v in output.items():
if isinstance(v, dict):
output[k] = _jnp_to_np(v)
elif isinstance(v, jnp.ndarray):
output[k] = np.array(v)
return output
def predict_structure( def predict_structure(
fasta_path: str, fasta_path: str,
fasta_name: str, fasta_name: str,
...@@ -231,10 +242,13 @@ def predict_structure( ...@@ -231,10 +242,13 @@ def predict_structure(
plddt = prediction_result['plddt'] plddt = prediction_result['plddt']
ranking_confidences[model_name] = prediction_result['ranking_confidence'] ranking_confidences[model_name] = prediction_result['ranking_confidence']
# Remove jax dependency from results.
np_prediction_result = _jnp_to_np(dict(prediction_result))
# Save the model outputs. # Save the model outputs.
result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
with open(result_output_path, 'wb') as f: with open(result_output_path, 'wb') as f:
pickle.dump(prediction_result, f, protocol=4) pickle.dump(np_prediction_result, f, protocol=4)
# Add the predicted LDDT in the b-factor column. # Add the predicted LDDT in the b-factor column.
# Note that higher predicted LDDT value means higher model confidence. # Note that higher predicted LDDT value means higher model confidence.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment