Commit 54317fe4 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add pTM config options, spruce up pretrained script, add test script

parent 3fb44548
......@@ -2,38 +2,72 @@ import copy
import ml_collections as mlc
def model_config(name, train=False):
    """Build the config for one of the named AlphaFold model presets.

    Args:
        name:
            One of "model_1" ... "model_5" or the corresponding
            "model_{1-5}_ptm" variants, as defined on the AlphaFold GitHub.
        train:
            If True, override inference-time defaults (checkpoint every
            block, disable chunking) with training-time settings.

    Returns:
        A deep copy of the module-level base ``config`` with the preset's
        overrides applied.

    Raises:
        ValueError: if ``name`` is not a recognized preset.
    """
    c = copy.deepcopy(config)

    # Presets 3-5 (and their pTM variants) run without template features.
    no_template = {
        "model_3", "model_4", "model_5",
        "model_3_ptm", "model_4_ptm", "model_5_ptm",
    }
    valid_names = no_template | {
        "model_1", "model_2", "model_1_ptm", "model_2_ptm",
    }
    if name not in valid_names:
        raise ValueError("Invalid model name")

    if name in no_template:
        c.model.template.enabled = False

    # pTM presets add the TM-score prediction head and its loss term.
    if name.endswith("_ptm"):
        c.model.heads.tm.enabled = True
        c.model.loss.tm.weight = 0.1

    if train:
        # NOTE(review): attribute paths kept verbatim from the commit; the
        # visible base config nests "globals" under "model", so
        # c.model.globals.* may be what is intended here — confirm against
        # the full file.
        c.globals.model.blocks_per_ckpt = 1
        c.globals.chunk_size = None

    return c
# Shared FieldReferences for recurring config values. Assigning through one
# of these (e.g. via config.model.globals below) updates every config entry
# that references it at once.
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# Inference-time defaults: no gradient checkpointing, chunked evaluation.
# model_config(..., train=True) flips these for training.
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(1e8, field_type=float)
config = mlc.ConfigDict({
"model": {
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"no_cycles": 2,#4,
"eps": eps,
"inf": inf,
},
"model": {
"no_cycles": 4,
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
......@@ -147,7 +181,7 @@ config = mlc.ConfigDict({
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 2,#8,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
......@@ -168,7 +202,7 @@ config = mlc.ConfigDict({
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": True,
"enabled": False,
},
"masked_msa": {
"c_m": c_m,
......@@ -245,7 +279,7 @@ config = mlc.ConfigDict({
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps,#1e-8,
"weight": 1.0,
"weight": 0.,
},
"eps": eps,
},
......
......@@ -202,12 +202,12 @@ class AlphaFold(nn.Module):
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m),
(*batch_dims, n, self.config.input_embedder.c_m),
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z),
(*batch_dims, n, n, self.config.input_embedder.c_z),
)
# [*, N, 3]
......
......@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import pickle
import os
# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import math
import pickle
import time
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from config import model_config
from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax
from openfold.utils.import_weights import (
import_jax_weights_,
......@@ -38,23 +39,23 @@ from openfold.utils.tensor_utils import (
tensor_tree_map,
)
MODEL_NAME = "model_1"
MODEL_DEVICE = "cuda:4"
PARAM_PATH = "openfold/resources/params/params_model_1.npz"
FEAT_PATH = "tests/test_data/sample_feats.pickle"
config = model_config(MODEL_NAME)
model = AlphaFold(config.model)
model = model.eval()
import_jax_weights_(model, PARAM_PATH)
model = model.to(MODEL_DEVICE)
def main(args):
config = model_config(args.model_name)
model = AlphaFold(config.model)
model = model.eval()
import_jax_weights_(model, args.param_path)
model = model.to(args.device)
with open(FEAT_PATH, "rb") as f:
with open(FEAT_PATH, "rb") as f:
batch = pickle.load(f)
with torch.no_grad():
batch = {k:torch.as_tensor(v, device=MODEL_DEVICE) for k,v in batch.items()}
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.device)
for k,v in batch.items()
}
longs = [
"aatype",
......@@ -69,43 +70,80 @@ with torch.no_grad():
batch[l] = batch[l].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous()
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
make_contig = lambda t: t.contiguous()
batch = tensor_tree_map(make_contig, batch)
t = time.time()
out = model(batch)
print(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
)
unrelaxed_protein = protein.from_prediction(
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors
)
)
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
amber_relaxer = relax.AmberRelaxation(
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
)
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
output_dir = '.'
relaxed_output_path = os.path.join(output_dir, f'relaxed_{MODEL_NAME}.pdb')
with open(relaxed_output_path, 'w') as f:
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'relaxed_{args.model_name}.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
    # CLI entry point: parse arguments, resolve a default parameter path
    # from the model name, then hand off to main() for inference.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction"""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from
             openfold/resources/params"""
    )
    args = parser.parse_args()

    # Default to the bundled AlphaFold weights matching the chosen preset,
    # so --param_path is only needed for non-standard locations.
    if(args.param_path is None):
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    main(args)
#!/bin/bash
# Run the repository's unit tests.
#
# Usage: ./script.sh [-v] [unittest args...]
#   -v  enable verbose unittest output
FLAGS=""
while getopts ":v" option; do
    case $option in
        v)
            FLAGS=$(echo "-v $FLAGS" | xargs) # strip whitespace
            ;;
        *)
            echo "Invalid option: ${option}"
            ;;
    esac
done

# Drop the options getopts already consumed; otherwise "-v" (or an invalid
# flag) would be forwarded to unittest a second time via "$@".
shift $((OPTIND-1))

python3 -m unittest $FLAGS "$@" || \
echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment