Commit f4043e1c authored by Gustaf Ahdritz

Add AlphaFold-Gap inference

parent 89f05497
......@@ -17,13 +17,15 @@ def model_config(name, train=False, low_prec=False):
pass
elif name == "finetuning":
# AF2 Suppl. Table 4, "finetuning" setting
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 512
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "model_1":
# AF2 Suppl. Table 5, Model 1.1.1
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
......@@ -36,17 +38,20 @@ def model_config(name, train=False, low_prec=False):
c.model.template.enabled = True
elif name == "model_3":
# AF2 Suppl. Table 5, Model 1.2.1
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_4":
# AF2 Suppl. Table 5, Model 1.2.2
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
elif name == "model_5":
# AF2 Suppl. Table 5, Model 1.2.3
c.model.template.enabled = False
elif name == "model_1_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.data.common.reduce_max_clusters_by_max_templates = True
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
......@@ -61,12 +66,14 @@ def model_config(name, train=False, low_prec=False):
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_3_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name == "model_4_ptm":
c.data.common.max_extra_msa = 5120
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
......@@ -184,7 +191,6 @@ config = mlc.ConfigDict(
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"max_recycling_iters": 3,
"msa_cluster_features": True,
"reduce_msa_clusters_by_max_templates": False,
......@@ -223,6 +229,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
......@@ -235,6 +242,7 @@ config = mlc.ConfigDict(
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"crop": False,
......@@ -247,6 +255,7 @@ config = mlc.ConfigDict(
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128,
"max_extra_msa": 1024,
"max_template_hits": 4,
"max_templates": 4,
"shuffle_top_k_prefiltered": 20,
......@@ -262,7 +271,7 @@ config = mlc.ConfigDict(
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 16,
"num_workers": 8,
},
},
},
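Taken together, the config hunks above move max_extra_msa out of data.common and into each data mode, so inference can cap the extra-MSA stack independently of training. A minimal sketch of the resulting layout, using the defaults from this diff (the model_* presets then raise train/predict to 5120):

data_cfg = {
    "train": {"max_extra_msa": 1024},
    "eval": {"max_extra_msa": 1024},
    "predict": {"max_extra_msa": 1024},
}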
......
......@@ -65,6 +65,47 @@ def make_template_features(
return template_features
def unify_template_features(
template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
out_dicts = []
seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
for i, fd in enumerate(template_feature_list):
out_dict = {}
n_templates, n_res = fd["template_aatype"].shape[:2]
for k,v in fd.items():
seq_keys = [
"template_aatype",
"template_all_atom_positions",
"template_all_atom_mask",
]
if(k in seq_keys):
new_shape = list(v.shape)
assert(new_shape[1] == n_res)
new_shape[1] = sum(seq_lens)
new_array = np.zeros(new_shape, dtype=v.dtype)
if(k == "template_aatype"):
new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
offset = sum(seq_lens[:i])
new_array[:, offset:offset + seq_lens[i]] = v
out_dict[k] = new_array
else:
out_dict[k] = v
chain_indices = np.array(n_templates * [i])
out_dict["template_chain_index"] = chain_indices
out_dicts.append(out_dict)
out_dict = {
k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
}
return out_dict
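A toy illustration of the padding scheme above, with made-up shapes and values: each chain's per-residue template features are placed block-diagonally in the unified arrays, leaving zeros (or gap one-hots, in the case of template_aatype) over the other chains' columns.

import numpy as np

# Two chains with 3 and 2 residues, one template each; arrays are
# [n_templates, n_res], values are illustrative.
seq_lens = [3, 2]
per_chain = [np.ones((1, 3)), 2 * np.ones((1, 2))]

unified = []
for i, arr in enumerate(per_chain):
    padded = np.zeros((arr.shape[0], sum(seq_lens)), dtype=arr.dtype)
    offset = sum(seq_lens[:i])
    padded[:, offset:offset + seq_lens[i]] = arr
    unified.append(padded)

print(np.concatenate(unified))
# [[1. 1. 1. 0. 0.]
#  [0. 0. 0. 2. 2.]]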
def make_sequence_features(
sequence: str, description: str, num_res: int
) -> FeatureDict:
......@@ -423,7 +464,6 @@ class DataPipeline:
_alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
......@@ -506,14 +546,12 @@ class DataPipeline:
return all_hits
def _process_msa_feats(
self,
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
_alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
......@@ -531,6 +569,17 @@ class DataPipeline:
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
return msas, deletion_matrices
def _process_msa_feats(
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, _alignment_index
)
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
......@@ -685,3 +734,92 @@ class DataPipeline:
return {**core_feats, **template_features, **msa_features}
def process_multiseq_fasta(self,
fasta_path: str,
super_alignment_dir: str,
ri_gap: int = 200,
) -> FeatureDict:
"""
Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
hack from Twitter. No templates.
"""
with open(fasta_path, 'r') as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
# No whitespace allowed
input_descs = [i.split()[0] for i in input_descs]
# Stitch all of the sequences together
input_sequence = ''.join(input_seqs)
input_description = '-'.join(input_descs)
num_res = len(input_sequence)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res,
)
seq_lens = [len(s) for s in input_seqs]
total_offset = 0
for sl in seq_lens:
total_offset += sl
sequence_features["residue_index"][total_offset:] += ri_gap
msa_list = []
deletion_mat_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
msas, deletion_mats = self._get_msas(
alignment_dir, seq, None
)
msa_list.append(msas)
deletion_mat_list.append(deletion_mats)
final_msa = []
final_deletion_mat = []
msa_it = enumerate(zip(msa_list, deletion_mat_list))
for i, (msas, deletion_mats) in msa_it:
prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
msas = [
[prec * '-' + seq + post * '-' for seq in msa] for msa in msas
]
deletion_mats = [
[prec * [0] + dml + post * [0] for dml in deletion_mat]
for deletion_mat in deletion_mats
]
assert(len(msas[0][-1]) == len(input_sequence))
final_msa.extend(msas)
final_deletion_mat.extend(deletion_mats)
msa_features = make_msa_features(
msas=final_msa,
deletion_matrices=final_deletion_mat,
)
template_feature_list = []
for seq, desc in zip(input_seqs, input_descs):
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, _alignment_index=None)
template_features = make_template_features(
seq,
hits,
self.template_featurizer,
)
template_feature_list.append(template_features)
template_features = unify_template_features(template_feature_list)
return {
**sequence_features,
**msa_features,
**template_features,
}
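A minimal, self-contained sketch of the two tricks process_multiseq_fasta relies on, with toy lengths and sequences: the residue-index gap that makes the monomer model treat the chains as spatially unconnected, and the gap-character padding that places each chain's MSA rows over its own columns only.

import numpy as np

# Residue-index gap: chains of length 3 and 2, ri_gap=200.
seq_lens, ri_gap = [3, 2], 200
residue_index = np.arange(sum(seq_lens))
total_offset = 0
for sl in seq_lens:
    total_offset += sl
    residue_index[total_offset:] += ri_gap
print(residue_index)  # [  0   1   2 203 204]

# MSA padding: each chain's rows get '-' over the other chains' columns.
msas = [["ABC", "ABD"], ["XY"]]  # one toy alignment per chain
final_msa = []
for i, msa in enumerate(msas):
    prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
    final_msa.append([prec * "-" + s + post * "-" for s in msa])
print(final_msa)  # [['ABC--', 'ABD--'], ['---XY']]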
......@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
......@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
......
......@@ -16,8 +16,9 @@
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
from typing import Any, Sequence, Mapping, Optional
import re
import string
from openfold.np import residue_constants
from Bio.PDB import PDBParser
......@@ -52,6 +53,19 @@ class Protein:
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
# Chain indices for multi-chain predictions
chain_index: Optional[np.ndarray] = None
# Optional remark about the protein. Included as a comment in output PDB
# files
remark: Optional[str] = None
# Templates used to generate this protein (prediction-only)
parents: Optional[Sequence[str]] = None
# Chain corresponding to each parent
parents_chain_index: Optional[Sequence[int]] = None
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
......@@ -188,6 +202,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
pdb_headers = []
remark = prot.remark
if(remark is not None):
pdb_headers.append(f"REMARK {remark}")
parents = prot.parents
parents_chain_index = prot.parents_chain_index
if(parents_chain_index is not None):
parents = [
p for i, p in zip(parents_chain_index, parents) if i == chain_id
]
if(parents is None or len(parents) == 0):
parents = ["N/A"]
pdb_headers.append(f"PARENT {' '.join(parents)}")
return pdb_headers
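A small sketch of the headers this produces for a two-chain prediction, using a stand-in object rather than a full Protein (all values are illustrative):

from types import SimpleNamespace

prot = SimpleNamespace(
    remark="no_recycling=3, max_templates=4, config_preset=model_1",
    parents=["1abc_A", "2xyz_B"],
    parents_chain_index=[0, 1],
)

for chain_id in (0, 1):
    # Same filtering as get_pdb_headers: keep this chain's parents only.
    parents = [
        p for i, p in zip(prot.parents_chain_index, prot.parents)
        if i == chain_id
    ] or ["N/A"]
    print(f"REMARK {prot.remark}")
    print(f"PARENT {' '.join(parents)}")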
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
......@@ -208,15 +244,21 @@ def to_pdb(prot: Protein) -> str:
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
b_factors = prot.b_factors
chain_index = prot.chain_index
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
pdb_lines.append("MODEL 1")
headers = get_pdb_headers(prot)
if(len(headers) > 0):
pdb_lines.extend(headers)
n = aatype.shape[0]
atom_index = 1
chain_id = "A"
prev_chain_index = 0
chain_tags = string.ascii_uppercase
# Add all atom sites.
for i in range(aatype.shape[0]):
for i in range(n):
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
......@@ -233,10 +275,15 @@ def to_pdb(prot: Protein) -> str:
0
] # Protein supports only C, N, O, S, this works.
charge = ""
chain_tag = "A"
if(chain_index is not None):
chain_tag = chain_tags[chain_index[i]]
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_id:>1}"
f"{res_name_3:>3} {chain_tag:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
......@@ -245,14 +292,27 @@ def to_pdb(prot: Protein) -> str:
pdb_lines.append(atom_line)
atom_index += 1
should_terminate = (i == n - 1)
if(chain_index is not None):
if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
should_terminate = True
prev_chain_index = chain_index[i + 1]
if(should_terminate):
# Close the chain.
chain_end = "TER"
chain_termination_line = (
f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} "
f"{chain_id:>1}{residue_index[-1]:>4}"
f"{chain_end:<6}{atom_index:>5} "
f"{res_1to3(aatype[i]):>3} "
f"{chain_tag:>1}{residue_index[i]:>4}"
)
pdb_lines.append(chain_termination_line)
pdb_lines.append("ENDMDL")
atom_index += 1
if(i != n - 1):
# "prev" is a misnomer here. This happens at the beginning of
# each new chain.
pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
pdb_lines.append("END")
pdb_lines.append("")
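One consequence of drawing chain tags from string.ascii_uppercase is that the writer supports at most 26 chains; the mapping is simply:

import string

chain_tags = string.ascii_uppercase
print([chain_tags[i] for i in range(3)])  # ['A', 'B', 'C']; IndexError past 'Z'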
......@@ -279,6 +339,10 @@ def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
chain_index: Optional[np.ndarray] = None,
remark: Optional[str] = None,
parents: Optional[Sequence[str]] = None,
parents_chain_index: Optional[Sequence[int]] = None
) -> Protein:
"""Assembles a protein from a prediction.
......@@ -286,7 +350,9 @@ def from_prediction(
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
chain_index: (Optional) Chain indices for multi-chain predictions
remark: (Optional) Remark about the prediction
parents: (Optional) List of template names
parents_chain_index: (Optional) Chain index corresponding to each parent
Returns:
A protein instance.
"""
......@@ -299,4 +365,8 @@ def from_prediction(
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
b_factors=b_factors,
chain_index=chain_index,
remark=remark,
parents=parents,
parents_chain_index=parents_chain_index,
)
......@@ -192,6 +192,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
if checks:
_check_cleaned_atoms(pdb_string, prot_pdb_string)
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
pdb_string = '\n'.join(['\n'.join(headers), pdb_string])
return pdb_string
......
......@@ -87,4 +87,9 @@ class AmberRelaxation(object):
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
headers = protein.get_pdb_headers(prot)
if(len(headers) > 0):
min_pdb = '\n'.join(['\n'.join(headers), min_pdb])
return min_pdb, debug_data, violations
......@@ -21,6 +21,9 @@ import numpy as np
import os
import pickle
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict
)
import random
import sys
import time
......@@ -42,12 +45,160 @@ from openfold.utils.tensor_utils import (
from scripts.utils import add_data_args
def precompute_alignments(tags, seqs, alignment_dir, args):
for tag, seq in zip(tags, seqs):
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
logging.info(f"Generating alignments for {tag}...")
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=(args.preset == "reduced_dbs"),  # reduced_dbs preset uses the small BFD
no_cpus=args.cpus,
)
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
)
# Remove temporary FASTA file
os.remove(tmp_fasta_path)
def run_model(model, batch, tag, args):
logging.info("Executing model...")
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
# Disable templates if there aren't any in the batch
model.config.template.enabled = any([
"template_" in k for k in batch
])
logging.info(f"Running inference for {tag}...")
t = time.perf_counter()
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
return out
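The per-batch template toggle above just checks the featurized batch for template_* keys; a quick illustration with a made-up key list:

batch_keys = ["aatype", "msa_feat", "residue_index"]
print(any("template_" in k for k in batch_keys))  # False -> templates disabled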
def prep_output(out, batch, feature_dict, feature_processor, args):
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
# Prep protein metadata
template_domain_names = []
template_chain_index = None
if(feature_processor.config.common.use_templates):
template_domain_names = [
t.decode("utf-8") for t in feature_dict["template_domain_names"]
]
# This works because templates are not shuffled during inference
template_domain_names = template_domain_names[
:feature_processor.config.predict.max_templates
]
if("template_chain_index" in feature_dict):
template_chain_index = feature_dict["template_chain_index"]
template_chain_index = template_chain_index[
:feature_processor.config.predict.max_templates
]
no_recycling = feature_processor.config.common.max_recycling_iters
remark = ', '.join([
f"no_recycling={no_recycling}",
f"max_templates={feature_processor.config.predict.max_templates}",
f"config_preset={args.model_name}",
])
# For multi-chain FASTAs
ri = feature_dict["residue_index"]
chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
chain_index = chain_index.astype(np.int64)
cur_chain = 0
prev_chain_max = 0
for i, c in enumerate(chain_index):
if(c != cur_chain):
cur_chain = c
prev_chain_max = i + cur_chain * args.multimer_ri_gap
batch["residue_index"][i] -= prev_chain_max
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors,
chain_index=chain_index,
remark=remark,
parents=template_domain_names,
parents_chain_index=template_chain_index,
)
return unrelaxed_protein
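A worked example of the chain-index arithmetic above, with toy lengths and a gap of 200: subtracting a plain arange leaves chain * gap at every position, so dividing by the gap recovers each residue's chain.

import numpy as np

gap = 200
# residue_index for chains of length 3 and 2, as produced by
# process_multiseq_fasta.
ri = np.array([0, 1, 2, 203, 204])
chain_index = ((ri - np.arange(ri.shape[0])) / gap).astype(np.int64)
print(chain_index)  # [0 0 0 1 1]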
def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
# Prep the model
config = model_config(args.model_name)
model = AlphaFold(config)
model = model.eval()
import_jax_weights_(model, args.param_path, version=args.model_name)
#script_preset_(model)
if(args.jax_param_path):
import_jax_weights_(
model, args.jax_param_path, version=args.model_name
)
elif(args.openfold_checkpoint_path):
if(os.path.isdir(args.openfold_checkpoint_path)):
checkpoint_basename = os.path.splitext(
os.path.basename(
os.path.normpath(args.openfold_checkpoint_path)
)
)[0]
ckpt_path = os.path.join(
args.output_dir,
checkpoint_basename + ".pt",
)
if(not os.path.isfile(ckpt_path)):
convert_zero_checkpoint_to_fp32_state_dict(
args.openfold_checkpoint_path,
ckpt_path,
)
else:
ckpt_path = args.openfold_checkpoint_path
d = torch.load(ckpt_path)
model.load_state_dict(d["ema"]["params"])
else:
raise ValueError(
"At least one of jax_param_path or openfold_checkpoint_path must "
"be specified."
)
model = model.to(args.model_device)
template_featurizer = templates.TemplateHitFeaturizer(
......@@ -77,6 +228,9 @@ def main(args):
else:
alignment_dir = args.use_precomputed_alignments
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in os.listdir(args.fasta_dir):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
......@@ -88,81 +242,59 @@ def main(args):
][1:]
tags, seqs = lines[::2], lines[1::2]
assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
tag, seq = tags[0], seqs[0]
tags = [t.split()[0] for t in tags]
assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
precompute_alignments(tags, seqs, alignment_dir, args)
fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
with open(fasta_path, "w") as fp:
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
if(len(seqs) == 1):
seq = seqs[0]
with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
logging.info("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if(args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
)
alignment_runner.run(
fasta_path, local_alignment_dir
else:
with open(tmp_fasta_path, "w") as fp:
fp.write(
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
)
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
feature_dict = data_processor.process_multiseq_fasta(
fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
)
# Remove temporary FASTA file
os.remove(fasta_path)
os.remove(tmp_fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict',
)
logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
t = time.perf_counter()
out = model(batch)
logging.info(f"Inference time: {time.perf_counter() - t}")
out = run_model(model, batch, tag, args)
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
unrelaxed_protein = prep_output(
out, batch, feature_dict, feature_processor, args
)
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors
)
output_name = f'{tag}_{args.model_name}'
if(args.output_postfix is not None):
output_name = f'{output_name}_{args.output_postfix}'
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb'
prediction_dir, f'{output_name}_unrelaxed.pdb'
)
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
with open(unrelaxed_output_path, 'w') as fp:
fp.write(protein.to_pdb(unrelaxed_protein))
if(not args.skip_relaxation):
amber_relaxer = relax.AmberRelaxation(
......@@ -182,14 +314,14 @@ def main(args):
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
prediction_dir, f'{output_name}_relaxed.pdb'
)
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
with open(relaxed_output_path, 'w') as fp:
fp.write(relaxed_pdb_str)
if(args.save_outputs):
output_dict_path = os.path.join(
args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
args.output_dir, f'{output_name}_output_dict.pkl'
)
with open(output_dict_path, "wb") as fp:
pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
......@@ -224,10 +356,15 @@ if __name__ == "__main__":
model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
)
parser.add_argument(
"--param_path", type=str, default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params"""
"--jax_param_path", type=str, default=None,
help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser.add_argument(
"--openfold_checkpoint_path", type=str, default=None,
help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
checkpoint directory or a .pt file"""
)
parser.add_argument(
"--save_outputs", action="store_true", default=False,
......@@ -241,17 +378,25 @@ if __name__ == "__main__":
"--preset", type=str, default='full_dbs',
choices=('reduced_dbs', 'full_dbs')
)
parser.add_argument(
"--output_postfix", type=str, default=None,
help="""Postfix for output prediction filenames"""
)
parser.add_argument(
"--data_random_seed", type=str, default=None
)
parser.add_argument(
"--skip_relaxation", action="store_true", default=False,
)
parser.add_argument(
"--multimer_ri_gap", type=int, default=200,
help="""Residue index offset between multiple sequences, if provided"""
)
add_data_args(parser)
args = parser.parse_args()
if(args.param_path is None):
args.param_path = os.path.join(
if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
args.jax_param_path = os.path.join(
"openfold", "resources", "params",
"params_" + args.model_name + ".npz"
)
......
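For reference, the new multi-chain path expects a single FASTA file with one record per chain, with per-chain alignments looked up by tag under the alignment directory. A hypothetical two-chain input (tags and sequences are made up):

>chain_A
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
>chain_B
GSHMSLYDVAEYAGVSYQTVSRVV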