Commit 7d53297c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update pretrained runner

parent 2065a6ca
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
# limitations under the License. # limitations under the License.
import os import os
#import sys
#sys.path.append("lib/conda/lib/python3.9/site-packages") # A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import math import math
import pickle import pickle
...@@ -28,10 +29,6 @@ from config import model_config ...@@ -28,10 +29,6 @@ from config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
#os.environ["OPENMM_DEFAULT_PLATFORM"] = "CPU"
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
#os.environ["OPENMM_CPU_THREADS"] = "16"
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
...@@ -43,10 +40,9 @@ from openfold.utils.tensor_utils import ( ...@@ -43,10 +40,9 @@ from openfold.utils.tensor_utils import (
MODEL_NAME = "model_1" MODEL_NAME = "model_1"
MODEL_DEVICE = "cuda:1" MODEL_DEVICE = "cuda:4"
PARAM_PATH = "openfold/resources/params/params_model_1.npz" PARAM_PATH = "openfold/resources/params/params_model_1.npz"
#FEAT_PATH = "tests/test_data/sample_feats.pickle" FEAT_PATH = "tests/test_data/sample_feats.pickle"
FEAT_PATH = "prediction/1OJN_feats.pickle"
config = model_config(MODEL_NAME) config = model_config(MODEL_NAME)
model = AlphaFold(config.model) model = AlphaFold(config.model)
...@@ -66,6 +62,8 @@ with torch.no_grad(): ...@@ -66,6 +62,8 @@ with torch.no_grad():
"extra_msa", "extra_msa",
"residx_atom37_to_atom14", "residx_atom37_to_atom14",
"residx_atom14_to_atom37", "residx_atom14_to_atom37",
"true_msa",
"residue_index",
] ]
for l in longs: for l in longs:
batch[l] = batch[l].long() batch[l] = batch[l].long()
...@@ -95,6 +93,8 @@ unrelaxed_protein = protein.from_prediction( ...@@ -95,6 +93,8 @@ unrelaxed_protein = protein.from_prediction(
b_factors=plddt_b_factors b_factors=plddt_b_factors
) )
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
**config.relax **config.relax
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment