Commit 54317fe4 authored by Gustaf Ahdritz
Browse files

Add pTM config options, spruce up pretrained script, add test script

parent 3fb44548
...@@ -2,38 +2,72 @@ import copy ...@@ -2,38 +2,72 @@ import copy
import ml_collections as mlc import ml_collections as mlc
def model_config(name, train=False):
    """Build a model-specific copy of the global config.

    Args:
        name: One of "model_{1-5}" or "model_{1-5}_ptm". Models 3-5 run
            without template features; the "_ptm" variants additionally
            enable the TM-score prediction head and its auxiliary loss.
        train: If True, enable per-block activation checkpointing and
            disable chunked inference (training-friendly settings).

    Returns:
        A deep copy of the module-level `config` ConfigDict, customized
        for the named model.

    Raises:
        ValueError: If `name` is not a recognized model name.
    """
    c = copy.deepcopy(config)

    base_names = ["model_%d" % i for i in range(1, 6)]
    valid_names = base_names + [n + "_ptm" for n in base_names]
    if name not in valid_names:
        raise ValueError("Invalid model name")

    # Models 3-5 (and their pTM variants) are trained without templates.
    if name in (
        "model_3", "model_4", "model_5",
        "model_3_ptm", "model_4_ptm", "model_5_ptm",
    ):
        c.model.template.enabled = False

    # The pTM variants add the TM-score head and its auxiliary loss.
    if name.endswith("_ptm"):
        c.model.heads.tm.enabled = True
        c.model.loss.tm.weight = 0.1

    if train:
        # BUG FIX: blocks_per_ckpt lives directly under c.globals (see the
        # "globals" section of the config below); the original wrote to a
        # nonexistent c.globals.model sub-config.
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None

    return c
# Recurring FieldReferences: updating one of these propagates the new value
# everywhere the reference is used inside the config below.
c_z = mlc.FieldReference(128, field_type=int)   # pair (z) embedding dim
c_m = mlc.FieldReference(256, field_type=int)   # MSA (m) embedding dim
c_t = mlc.FieldReference(64, field_type=int)    # template embedding dim
c_e = mlc.FieldReference(64, field_type=int)    # extra-MSA embedding dim
c_s = mlc.FieldReference(384, field_type=int)   # single (s) embedding dim
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)  # None = no checkpointing
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(1e8, field_type=float)
config = mlc.ConfigDict({ config = mlc.ConfigDict({
"model": { # Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"c_t": c_t, "c_t": c_t,
"c_e": c_e, "c_e": c_e,
"c_s": c_s, "c_s": c_s,
"no_cycles": 2,#4, "eps": eps,
"inf": inf,
},
"model": {
"no_cycles": 4,
"_mask_trans": False, "_mask_trans": False,
"input_embedder": { "input_embedder": {
"tf_dim": 22, "tf_dim": 22,
...@@ -147,7 +181,7 @@ config = mlc.ConfigDict({ ...@@ -147,7 +181,7 @@ config = mlc.ConfigDict({
"no_qk_points": 4, "no_qk_points": 4,
"no_v_points": 8, "no_v_points": 8,
"dropout_rate": 0.1, "dropout_rate": 0.1,
"no_blocks": 2,#8, "no_blocks": 8,
"no_transition_layers": 1, "no_transition_layers": 1,
"no_resnet_blocks": 2, "no_resnet_blocks": 2,
"no_angles": 7, "no_angles": 7,
...@@ -168,7 +202,7 @@ config = mlc.ConfigDict({ ...@@ -168,7 +202,7 @@ config = mlc.ConfigDict({
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": True, "enabled": False,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
...@@ -245,7 +279,7 @@ config = mlc.ConfigDict({ ...@@ -245,7 +279,7 @@ config = mlc.ConfigDict({
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps,#1e-8, "eps": eps,#1e-8,
"weight": 1.0, "weight": 0.,
}, },
"eps": eps, "eps": eps,
}, },
......
...@@ -202,12 +202,12 @@ class AlphaFold(nn.Module): ...@@ -202,12 +202,12 @@ class AlphaFold(nn.Module):
if(None in [m_1_prev, z_prev, x_prev]): if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m), (*batch_dims, n, self.config.input_embedder.c_m),
) )
# [*, N, N, C_z] # [*, N, N, C_z]
z_prev = z.new_zeros( z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z), (*batch_dims, n, n, self.config.input_embedder.c_z),
) )
# [*, N, 3] # [*, N, 3]
......
...@@ -13,22 +13,23 @@ ...@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import math
import pickle
import os import os
# A hack to get OpenMM and PyTorch to peacefully coexist # A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL" os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import math
import pickle
import time import time
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from config import model_config from config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
...@@ -38,74 +39,111 @@ from openfold.utils.tensor_utils import ( ...@@ -38,74 +39,111 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
FEAT_PATH = "tests/test_data/sample_feats.pickle"


def main(args):
    """Run AlphaFold inference on the pickled sample features, then relax
    and save the predicted structure as a PDB file."""
    config = model_config(args.model_name)
    model = AlphaFold(config.model)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
    model = model.to(args.device)

    with open(FEAT_PATH, "rb") as f:
        batch = pickle.load(f)

    with torch.no_grad():
        batch = {
            k: torch.as_tensor(v, device=args.device)
            for k, v in batch.items()
        }

        # Index-valued features must be integer-typed downstream.
        longs = [
            "aatype",
            "template_aatype",
            "extra_msa",
            "residx_atom37_to_atom14",
            "residx_atom14_to_atom37",
            "true_msa",
            "residue_index",
        ]
        for l in longs:
            batch[l] = batch[l].long()

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
        batch = tensor_tree_map(move_dim, batch)
        make_contig = lambda t: t.contiguous()
        batch = tensor_tree_map(make_contig, batch)

        t = time.time()
        out = model(batch)
        print(f"Inference time: {time.time() - t}")

    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

    plddt = out["plddt"]
    mean_plddt = np.mean(plddt)

    # Broadcast per-residue pLDDT over all atom types for the B-factor column.
    plddt_b_factors = np.repeat(
        plddt[..., None], residue_constants.atom_type_num, axis=-1
    )

    unrelaxed_protein = protein.from_prediction(
        features=batch,
        result=out,
        b_factors=plddt_b_factors
    )

    # NOTE(review): hard-coded GPU index for the relaxation step — confirm
    # intent and consider making it a command-line argument.
    os.environ["CUDA_VISIBLE_DEVICES"] = "7"

    amber_relaxer = relax.AmberRelaxation(
        **config.relax
    )

    # Relax the prediction.
    t = time.time()
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
    print(f"Relaxation time: {time.time() - t}")

    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
        args.output_dir, f'relaxed_{args.model_name}.pdb'
    )
    with open(relaxed_output_path, 'w') as f:
        f.write(relaxed_pdb_str)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction"""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
             automatically according to the model name from
             openfold/resources/params"""
    )

    args = parser.parse_args()

    if args.param_path is None:
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    main(args)
#!/bin/bash
# Run the project's Python unit tests.
#
# Usage: ./script.sh [-v] [unittest args...]
#   -v  verbose unittest output

FLAGS=""
while getopts ":v" option; do
    case $option in
        v)
            FLAGS=$(echo "-v $FLAGS" | xargs) # strip whitespace
            ;;
        *)
            # BUG FIX: with a leading ':' in the optstring, getopts puts the
            # offending character in OPTARG; $option is always '?' here, so
            # the original message printed "Invalid option: ?" every time.
            echo "Invalid option: -${OPTARG}"
            ;;
    esac
done
# BUG FIX: drop the options already parsed by getopts so that "-v" is not
# forwarded to unittest twice (once via $FLAGS, once via "$@").
shift $((OPTIND - 1))

python3 -m unittest $FLAGS "$@" || \
    echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment