Commit 91ac85ac authored by Hamish Tomlinson's avatar Hamish Tomlinson Committed by Copybara-Service
Browse files

Remove jax dependency from results pkl

PiperOrigin-RevId: 502584418
Change-Id: I8ab14363a6342726691213bbeb180c0a8ff7f932
parent 684ffa19
......@@ -22,7 +22,7 @@ import random
import shutil
import sys
import time
from typing import Dict, Union
from typing import Any, Dict, Mapping, Union
from absl import app
from absl import flags
......@@ -38,6 +38,7 @@ from alphafold.model import config
from alphafold.model import data
from alphafold.model import model
from alphafold.relax import relax
import jax.numpy as jnp
import numpy as np
# Internal import (7716).
......@@ -160,6 +161,16 @@ def _check_flag(flag_name: str,
f'"--{other_flag_name}={FLAGS[other_flag_name].value}".')
def _jnp_to_np(output: dict[str, Any]) -> dict[str, Any]:
"""Recursively changes jax arrays to numpy arrays."""
for k, v in output.items():
if isinstance(v, dict):
output[k] = _jnp_to_np(v)
elif isinstance(v, jnp.ndarray):
output[k] = np.array(v)
return output
def predict_structure(
fasta_path: str,
fasta_name: str,
......@@ -231,10 +242,13 @@ def predict_structure(
plddt = prediction_result['plddt']
ranking_confidences[model_name] = prediction_result['ranking_confidence']
# Remove jax dependency from results.
np_prediction_result = _jnp_to_np(dict(prediction_result))
# Save the model outputs.
result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl')
with open(result_output_path, 'wb') as f:
pickle.dump(prediction_result, f, protocol=4)
pickle.dump(np_prediction_result, f, protocol=4)
# Add the predicted LDDT in the b-factor column.
# Note that higher predicted LDDT value means higher model confidence.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment