Commit 54317fe4 authored by Gustaf Ahdritz
Browse files

Add pTM config options, spruce up pretrained script, add test script

parent 3fb44548
...@@ -2,38 +2,72 @@ import copy ...@@ -2,38 +2,72 @@ import copy
import ml_collections as mlc import ml_collections as mlc
def model_config(name): def model_config(name, train=False):
c = copy.deepcopy(config) c = copy.deepcopy(config)
if(name == "model_3"): if(name == "model_1"):
pass
elif(name == "model_2"):
pass
elif(name == "model_3"):
c.model.template.enabled = False c.model.template.enabled = False
elif(name == "model_4"): elif(name == "model_4"):
c.model.template.enabled = False c.model.template.enabled = False
elif(name == "model_5"): elif(name == "model_5"):
c.model.template.enabled = False c.model.template.enabled = False
elif(name == "model_1_ptm"):
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
elif(name == "model_2_ptm"):
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
elif(name == "model_3_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
elif(name == "model_4_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
elif(name == "model_5_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")
return c if(train):
c.globals.model.blocks_per_ckpt = 1
c.globals.chunk_size = None
return c
c_z = mlc.FieldReference(128)
c_m = mlc.FieldReference(256)
c_t = mlc.FieldReference(64)
c_e = mlc.FieldReference(64)
c_s = mlc.FieldReference(384)
blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
chunk_size = mlc.FieldReference(None, field_type=int)
aux_distogram_bins = mlc.FieldReference(64)
eps = 1e-8 c_z = mlc.FieldReference(128, field_type=int)
inf = 1e8 c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(1e8, field_type=float)
config = mlc.ConfigDict({ config = mlc.ConfigDict({
"model": { # Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z, "c_z": c_z,
"c_m": c_m, "c_m": c_m,
"c_t": c_t, "c_t": c_t,
"c_e": c_e, "c_e": c_e,
"c_s": c_s, "c_s": c_s,
"no_cycles": 2,#4, "eps": eps,
"inf": inf,
},
"model": {
"no_cycles": 4,
"_mask_trans": False, "_mask_trans": False,
"input_embedder": { "input_embedder": {
"tf_dim": 22, "tf_dim": 22,
...@@ -147,7 +181,7 @@ config = mlc.ConfigDict({ ...@@ -147,7 +181,7 @@ config = mlc.ConfigDict({
"no_qk_points": 4, "no_qk_points": 4,
"no_v_points": 8, "no_v_points": 8,
"dropout_rate": 0.1, "dropout_rate": 0.1,
"no_blocks": 2,#8, "no_blocks": 8,
"no_transition_layers": 1, "no_transition_layers": 1,
"no_resnet_blocks": 2, "no_resnet_blocks": 2,
"no_angles": 7, "no_angles": 7,
...@@ -168,7 +202,7 @@ config = mlc.ConfigDict({ ...@@ -168,7 +202,7 @@ config = mlc.ConfigDict({
"tm": { "tm": {
"c_z": c_z, "c_z": c_z,
"no_bins": aux_distogram_bins, "no_bins": aux_distogram_bins,
"enabled": True, "enabled": False,
}, },
"masked_msa": { "masked_msa": {
"c_m": c_m, "c_m": c_m,
...@@ -245,7 +279,7 @@ config = mlc.ConfigDict({ ...@@ -245,7 +279,7 @@ config = mlc.ConfigDict({
"min_resolution": 0.1, "min_resolution": 0.1,
"max_resolution": 3.0, "max_resolution": 3.0,
"eps": eps,#1e-8, "eps": eps,#1e-8,
"weight": 1.0, "weight": 0.,
}, },
"eps": eps, "eps": eps,
}, },
......
...@@ -202,12 +202,12 @@ class AlphaFold(nn.Module): ...@@ -202,12 +202,12 @@ class AlphaFold(nn.Module):
if(None in [m_1_prev, z_prev, x_prev]): if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m), (*batch_dims, n, self.config.input_embedder.c_m),
) )
# [*, N, N, C_z] # [*, N, N, C_z]
z_prev = z.new_zeros( z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z), (*batch_dims, n, n, self.config.input_embedder.c_z),
) )
# [*, N, 3] # [*, N, 3]
......
...@@ -13,22 +13,23 @@ ...@@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import math
import pickle
import os import os
# A hack to get OpenMM and PyTorch to peacefully coexist # A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL" os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import math
import pickle
import time import time
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from config import model_config from config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.np import residue_constants, protein from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax import openfold.np.relax.relax as relax
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
...@@ -38,23 +39,23 @@ from openfold.utils.tensor_utils import ( ...@@ -38,23 +39,23 @@ from openfold.utils.tensor_utils import (
tensor_tree_map, tensor_tree_map,
) )
MODEL_NAME = "model_1"
MODEL_DEVICE = "cuda:4"
PARAM_PATH = "openfold/resources/params/params_model_1.npz"
FEAT_PATH = "tests/test_data/sample_feats.pickle" FEAT_PATH = "tests/test_data/sample_feats.pickle"
config = model_config(MODEL_NAME) def main(args):
model = AlphaFold(config.model) config = model_config(args.model_name)
model = model.eval() model = AlphaFold(config.model)
import_jax_weights_(model, PARAM_PATH) model = model.eval()
model = model.to(MODEL_DEVICE) import_jax_weights_(model, args.param_path)
model = model.to(args.device)
with open(FEAT_PATH, "rb") as f: with open(FEAT_PATH, "rb") as f:
batch = pickle.load(f) batch = pickle.load(f)
with torch.no_grad(): with torch.no_grad():
batch = {k:torch.as_tensor(v, device=MODEL_DEVICE) for k,v in batch.items()} batch = {
k:torch.as_tensor(v, device=args.device)
for k,v in batch.items()
}
longs = [ longs = [
"aatype", "aatype",
...@@ -69,43 +70,80 @@ with torch.no_grad(): ...@@ -69,43 +70,80 @@ with torch.no_grad():
batch[l] = batch[l].long() batch[l] = batch[l].long()
# Move the recycling dimension to the end # Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0).contiguous() move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch) batch = tensor_tree_map(move_dim, batch)
make_contig = lambda t: t.contiguous()
batch = tensor_tree_map(make_contig, batch)
t = time.time() t = time.time()
out = model(batch) out = model(batch)
print(f"Inference time: {time.time() - t}") print(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore # Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch) batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"] plddt = out["plddt"]
mean_plddt = np.mean(plddt) mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat( plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1 plddt[..., None], residue_constants.atom_type_num, axis=-1
) )
unrelaxed_protein = protein.from_prediction( unrelaxed_protein = protein.from_prediction(
features=batch, features=batch,
result=out, result=out,
b_factors=plddt_b_factors b_factors=plddt_b_factors
) )
os.environ["CUDA_VISIBLE_DEVICES"] = "7" os.environ["CUDA_VISIBLE_DEVICES"] = "7"
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
**config.relax **config.relax
) )
# Relax the prediction. # Relax the prediction.
t = time.time() t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.time() - t}") print(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB. # Save the relaxed PDB.
output_dir = '.' relaxed_output_path = os.path.join(
relaxed_output_path = os.path.join(output_dir, f'relaxed_{MODEL_NAME}.pdb') args.output_dir, f'relaxed_{args.model_name}.pdb'
with open(relaxed_output_path, 'w') as f: )
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str) f.write(relaxed_pdb_str)
if __name__ == "__main__":
    # Command-line entry point: parse the options, fill in the default
    # parameter path, make sure the output directory exists, then hand the
    # parsed namespace to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
                device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--model_name", type=str, default="model_1",
        help="""Name of a model config. Choose one of model_{1-5} or
                model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction"""
    )
    parser.add_argument(
        "--param_path", type=str, default=None,
        help="""Path to model parameters. If None, parameters are selected
                automatically according to the model name from
                openfold/resources/params"""
    )

    args = parser.parse_args()

    # No explicit parameter file given: derive it from the model name,
    # relative to the current working directory.
    if args.param_path is None:
        args.param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.model_name + ".npz"
        )

    # main() only opens the output file after inference and relaxation have
    # finished, so create the directory now rather than fail at the very end.
    os.makedirs(args.output_dir, exist_ok=True)

    main(args)
#!/bin/bash
# Thin wrapper around `python3 -m unittest`.
#
# Usage: <script> [-v] [unittest arguments...]
#   -v   request verbose unittest output
#
# All original arguments are still forwarded to unittest unchanged.
# The script exits with unittest's exit status so callers (CI, shells using
# `&&`) can detect failing tests.

FLAGS=""
while getopts ":v" option; do
    case $option in
        v)
            FLAGS=$(echo "-v $FLAGS" | xargs) # strip whitespace
            ;;
        *)
            # The leading ":" in the optstring selects silent error reporting:
            # getopts sets $option to "?" and stores the offending character
            # in $OPTARG, so report that instead of a literal "?".
            echo "Invalid option: -${OPTARG}"
            ;;
    esac
done

python3 -m unittest $FLAGS "$@"
status=$?

if [ "$status" -ne 0 ]; then
    echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
fi

# Propagate the test result instead of the exit status of `echo`.
exit "$status"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment