Commit f4043e1c authored by Gustaf Ahdritz

Add AlphaFold-Gap inference

parent 89f05497
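Context for reviewers: "AlphaFold-Gap" is the community trick (popularized by Minkyung Baek) for coaxing the monomer model into complex prediction. The chains are concatenated into one sequence, a large residue-index gap is inserted at each chain boundary so the model's relative-position features register a chain break, and each chain's MSA is padded with gaps over the other chains' columns. A minimal sketch of the index trick, with illustrative names only; the committed implementation is `process_multiseq_fasta` below:

```python
# Minimal sketch of the residue-index gap trick (illustrative only).
import numpy as np

def gapped_residue_index(seq_lens, ri_gap=200):
    """Concatenate per-chain indices, adding `ri_gap` at each chain
    boundary so relative-position features see a chain break."""
    ri = np.arange(sum(seq_lens))
    offset = 0
    for sl in seq_lens[:-1]:
        offset += sl
        ri[offset:] += ri_gap
    return ri

print(gapped_residue_index([3, 2]).tolist())  # [0, 1, 2, 203, 204]
```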
@@ -17,13 +17,15 @@ def model_config(name, train=False, low_prec=False):
         pass
     elif name == "finetuning":
         # AF2 Suppl. Table 4, "finetuning" setting
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
         c.data.train.crop_size = 384
         c.data.train.max_msa_clusters = 512
         c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
     elif name == "model_1":
         # AF2 Suppl. Table 5, Model 1.1.1
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.data.common.reduce_max_clusters_by_max_templates = True
         c.data.common.use_templates = True
         c.data.common.use_template_torsion_angles = True
@@ -36,17 +38,20 @@ def model_config(name, train=False, low_prec=False):
         c.model.template.enabled = True
     elif name == "model_3":
         # AF2 Suppl. Table 5, Model 1.2.1
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.model.template.enabled = False
     elif name == "model_4":
         # AF2 Suppl. Table 5, Model 1.2.2
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.model.template.enabled = False
     elif name == "model_5":
         # AF2 Suppl. Table 5, Model 1.2.3
         c.model.template.enabled = False
     elif name == "model_1_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.data.common.reduce_max_clusters_by_max_templates = True
         c.data.common.use_templates = True
         c.data.common.use_template_torsion_angles = True
@@ -61,12 +66,14 @@ def model_config(name, train=False, low_prec=False):
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
     elif name == "model_3_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
     elif name == "model_4_ptm":
-        c.data.common.max_extra_msa = 5120
+        c.data.train.max_extra_msa = 5120
+        c.data.predict.max_extra_msa = 5120
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
@@ -184,7 +191,6 @@ config = mlc.ConfigDict(
                     "same_prob": 0.1,
                     "uniform_prob": 0.1,
                 },
-                "max_extra_msa": 1024,
                 "max_recycling_iters": 3,
                 "msa_cluster_features": True,
                 "reduce_msa_clusters_by_max_templates": False,
@@ -223,6 +229,7 @@ config = mlc.ConfigDict(
                 "subsample_templates": False,  # We want top templates.
                 "masked_msa_replace_fraction": 0.15,
                 "max_msa_clusters": 128,
+                "max_extra_msa": 1024,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "crop": False,
@@ -235,6 +242,7 @@ config = mlc.ConfigDict(
                 "subsample_templates": False,  # We want top templates.
                 "masked_msa_replace_fraction": 0.15,
                 "max_msa_clusters": 128,
+                "max_extra_msa": 1024,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "crop": False,
@@ -247,6 +255,7 @@ config = mlc.ConfigDict(
                 "subsample_templates": True,
                 "masked_msa_replace_fraction": 0.15,
                 "max_msa_clusters": 128,
+                "max_extra_msa": 1024,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "shuffle_top_k_prefiltered": 20,
@@ -262,7 +271,7 @@ config = mlc.ConfigDict(
             "use_small_bfd": False,
             "data_loaders": {
                 "batch_size": 1,
-                "num_workers": 16,
+                "num_workers": 8,
             },
         },
     },
...
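The config change above moves `max_extra_msa` from `data.common` down into the per-mode sections (`train`/`eval`/`predict`), so inference can size the extra-MSA stack independently of training. A hedged `ml_collections` sketch of the resulting layout (keys abbreviated, not the full config):

```python
import ml_collections as mlc

c = mlc.ConfigDict({
    "data": {
        "common": {"max_recycling_iters": 3},   # shared settings stay here
        "train": {"max_extra_msa": 1024},       # per-mode defaults
        "predict": {"max_extra_msa": 1024},
    }
})

# What a preset such as "model_1" above now does:
c.data.train.max_extra_msa = 5120
c.data.predict.max_extra_msa = 5120
```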
@@ -65,6 +65,47 @@ def make_template_features(
     return template_features


+def unify_template_features(
+    template_feature_list: Sequence[FeatureDict]
+) -> FeatureDict:
+    out_dicts = []
+    seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
+    for i, fd in enumerate(template_feature_list):
+        out_dict = {}
+        n_templates, n_res = fd["template_aatype"].shape[:2]
+        for k, v in fd.items():
+            seq_keys = [
+                "template_aatype",
+                "template_all_atom_positions",
+                "template_all_atom_mask",
+            ]
+            if(k in seq_keys):
+                new_shape = list(v.shape)
+                assert(new_shape[1] == n_res)
+                new_shape[1] = sum(seq_lens)
+                new_array = np.zeros(new_shape, dtype=v.dtype)
+
+                if(k == "template_aatype"):
+                    new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1
+
+                offset = sum(seq_lens[:i])
+                new_array[:, offset:offset + seq_lens[i]] = v
+                out_dict[k] = new_array
+            else:
+                out_dict[k] = v
+
+        chain_indices = np.array(n_templates * [i])
+        out_dict["template_chain_index"] = chain_indices
+
+        out_dicts.append(out_dict)
+
+    out_dict = {
+        k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
+    }
+
+    return out_dict
+
+
 def make_sequence_features(
     sequence: str, description: str, num_res: int
 ) -> FeatureDict:
@@ -422,8 +463,7 @@ class DataPipeline:
         alignment_dir: str,
         _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
-
         msa_data = {}
         if(_alignment_index is not None):
             fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
@@ -506,14 +546,12 @@ class DataPipeline:
         return all_hits

-    def _process_msa_feats(
-        self,
+    def _get_msas(self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,
-        _alignment_index: Optional[str] = None
-    ) -> Mapping[str, Any]:
+        _alignment_index: Optional[str] = None,
+    ):
         msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
         if(len(msa_data) == 0):
             if(input_sequence is None):
                 raise ValueError(
@@ -531,6 +569,17 @@ class DataPipeline:
             (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
         ])

+        return msas, deletion_matrices
+
+    def _process_msa_feats(
+        self,
+        alignment_dir: str,
+        input_sequence: Optional[str] = None,
+        _alignment_index: Optional[str] = None
+    ) -> Mapping[str, Any]:
+        msas, deletion_matrices = self._get_msas(
+            alignment_dir, input_sequence, _alignment_index
+        )
         msa_features = make_msa_features(
             msas=msas,
             deletion_matrices=deletion_matrices,
@@ -685,3 +734,92 @@ class DataPipeline:
         return {**core_feats, **template_features, **msa_features}
+
+    def process_multiseq_fasta(self,
+        fasta_path: str,
+        super_alignment_dir: str,
+        ri_gap: int = 200,
+    ) -> FeatureDict:
+        """
+            Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
+            hack from Twitter. No templates.
+        """
+        with open(fasta_path, 'r') as f:
+            fasta_str = f.read()
+        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
+
+        # No whitespace allowed
+        input_descs = [i.split()[0] for i in input_descs]
+
+        # Stitch all of the sequences together
+        input_sequence = ''.join(input_seqs)
+        input_description = '-'.join(input_descs)
+        num_res = len(input_sequence)
+
+        sequence_features = make_sequence_features(
+            sequence=input_sequence,
+            description=input_description,
+            num_res=num_res,
+        )
+
+        seq_lens = [len(s) for s in input_seqs]
+        total_offset = 0
+        for sl in seq_lens:
+            total_offset += sl
+            sequence_features["residue_index"][total_offset:] += ri_gap
+
+        msa_list = []
+        deletion_mat_list = []
+        for seq, desc in zip(input_seqs, input_descs):
+            alignment_dir = os.path.join(
+                super_alignment_dir, desc
+            )
+            msas, deletion_mats = self._get_msas(
+                alignment_dir, seq, None
+            )
+            msa_list.append(msas)
+            deletion_mat_list.append(deletion_mats)
+
+        final_msa = []
+        final_deletion_mat = []
+        msa_it = enumerate(zip(msa_list, deletion_mat_list))
+        for i, (msas, deletion_mats) in msa_it:
+            prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
+            msas = [
+                [prec * '-' + seq + post * '-' for seq in msa] for msa in msas
+            ]
+            deletion_mats = [
+                [prec * [0] + dml + post * [0] for dml in deletion_mat]
+                for deletion_mat in deletion_mats
+            ]
+
+            assert(len(msas[0][-1]) == len(input_sequence))
+            final_msa.extend(msas)
+            final_deletion_mat.extend(deletion_mats)
+
+        msa_features = make_msa_features(
+            msas=final_msa,
+            deletion_matrices=final_deletion_mat,
+        )
+
+        template_feature_list = []
+        for seq, desc in zip(input_seqs, input_descs):
+            alignment_dir = os.path.join(
+                super_alignment_dir, desc
+            )
+            hits = self._parse_template_hits(alignment_dir, _alignment_index=None)
+            template_features = make_template_features(
+                seq,
+                hits,
+                self.template_featurizer,
+            )
+            template_feature_list.append(template_features)
+
+        template_features = unify_template_features(template_feature_list)
+
+        return {
+            **sequence_features,
+            **msa_features,
+            **template_features,
+        }
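The heart of `process_multiseq_fasta` is the block-diagonal MSA: each chain's alignment rows are padded with gap characters over every other chain's columns, so no spurious cross-chain covariation is introduced. A toy worked example of that padding loop (sequences invented):

```python
# Hedged sketch of the per-chain MSA merge above.
seqs = ["MKV", "GH"]             # two chains
msas = [["MKV", "MRV"], ["GH"]]  # one toy MSA per chain
seq_lens = [len(s) for s in seqs]

final_msa = []
for i, msa in enumerate(msas):
    # Columns before/after this chain are filled with gaps
    prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
    final_msa.extend(prec * '-' + s + post * '-' for s in msa)

print(final_msa)  # ['MKV--', 'MRV--', '---GH']
```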
...
@@ -84,7 +84,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
     pad_msa_clusters = mode_cfg.max_msa_clusters
     max_msa_clusters = pad_msa_clusters
-    max_extra_msa = common_cfg.max_extra_msa
+    max_extra_msa = mode_cfg.max_extra_msa

     msa_seed = None
     if(not common_cfg.resample_msa_in_recycling):
@@ -137,7 +137,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
         data_transforms.make_fixed_size(
             crop_feats,
             pad_msa_clusters,
-            common_cfg.max_extra_msa,
+            mode_cfg.max_extra_msa,
             mode_cfg.crop_size,
             mode_cfg.max_templates,
         )
...
@@ -16,8 +16,9 @@
 """Protein data type."""
 import dataclasses
 import io
-from typing import Any, Mapping, Optional
+from typing import Any, Sequence, Mapping, Optional
 import re
+import string

 from openfold.np import residue_constants
 from Bio.PDB import PDBParser
@@ -52,6 +53,19 @@ class Protein:
     # value.
     b_factors: np.ndarray  # [num_res, num_atom_type]

+    # Chain indices for multi-chain predictions
+    chain_index: Optional[np.ndarray] = None
+
+    # Optional remark about the protein. Included as a comment in output PDB
+    # files
+    remark: Optional[str] = None
+
+    # Templates used to generate this protein (prediction-only)
+    parents: Optional[Sequence[str]] = None
+
+    # Chain corresponding to each parent
+    parents_chain_index: Optional[Sequence[int]] = None
+

 def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
     """Takes a PDB string and constructs a Protein object.
@@ -188,6 +202,28 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
     )


+def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
+    pdb_headers = []
+
+    remark = prot.remark
+    if(remark is not None):
+        pdb_headers.append(f"REMARK {remark}")
+
+    parents = prot.parents
+    parents_chain_index = prot.parents_chain_index
+    if(parents_chain_index is not None):
+        parents = [
+            p for i, p in zip(parents_chain_index, parents) if i == chain_id
+        ]
+
+    if(parents is None or len(parents) == 0):
+        parents = ["N/A"]
+
+    pdb_headers.append(f"PARENT {' '.join(parents)}")
+
+    return pdb_headers
+
+
 def to_pdb(prot: Protein) -> str:
     """Converts a `Protein` instance to a PDB string.
@@ -208,15 +244,21 @@ def to_pdb(prot: Protein) -> str:
     atom_positions = prot.atom_positions
     residue_index = prot.residue_index.astype(np.int32)
     b_factors = prot.b_factors
+    chain_index = prot.chain_index

     if np.any(aatype > residue_constants.restype_num):
         raise ValueError("Invalid aatypes.")

-    pdb_lines.append("MODEL     1")
+    headers = get_pdb_headers(prot)
+    if(len(headers) > 0):
+        pdb_lines.extend(headers)

+    n = aatype.shape[0]
     atom_index = 1
-    chain_id = "A"
+    prev_chain_index = 0
+    chain_tags = string.ascii_uppercase
+
     # Add all atom sites.
-    for i in range(aatype.shape[0]):
+    for i in range(n):
         res_name_3 = res_1to3(aatype[i])
         for atom_name, pos, mask, b_factor in zip(
             atom_types, atom_positions[i], atom_mask[i], b_factors[i]
@@ -233,10 +275,15 @@ def to_pdb(prot: Protein) -> str:
                 0
             ]  # Protein supports only C, N, O, S, this works.
             charge = ""

+            chain_tag = "A"
+            if(chain_index is not None):
+                chain_tag = chain_tags[chain_index[i]]
+
             # PDB is a columnar format, every space matters here!
             atom_line = (
                 f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
-                f"{res_name_3:>3} {chain_id:>1}"
+                f"{res_name_3:>3} {chain_tag:>1}"
                 f"{residue_index[i]:>4}{insertion_code:>1}   "
                 f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                 f"{occupancy:>6.2f}{b_factor:>6.2f}          "
@@ -245,14 +292,27 @@ def to_pdb(prot: Protein) -> str:
             pdb_lines.append(atom_line)
             atom_index += 1

-    # Close the chain.
-    chain_end = "TER"
-    chain_termination_line = (
-        f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
-        f"{chain_id:>1}{residue_index[-1]:>4}"
-    )
-    pdb_lines.append(chain_termination_line)
-    pdb_lines.append("ENDMDL")
+        should_terminate = (i == n - 1)
+        if(chain_index is not None):
+            if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
+                should_terminate = True
+                prev_chain_index = chain_index[i + 1]
+
+        if(should_terminate):
+            # Close the chain.
+            chain_end = "TER"
+            chain_termination_line = (
+                f"{chain_end:<6}{atom_index:>5}      "
+                f"{res_1to3(aatype[i]):>3} "
+                f"{chain_tag:>1}{residue_index[i]:>4}"
+            )
+            pdb_lines.append(chain_termination_line)
+            atom_index += 1
+
+            if(i != n - 1):
+                # "prev" is a misnomer here. This happens at the beginning of
+                # each new chain.
+                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

     pdb_lines.append("END")
     pdb_lines.append("")
@@ -279,6 +339,10 @@ def from_prediction(
     features: FeatureDict,
     result: ModelOutput,
     b_factors: Optional[np.ndarray] = None,
+    chain_index: Optional[np.ndarray] = None,
+    remark: Optional[str] = None,
+    parents: Optional[Sequence[str]] = None,
+    parents_chain_index: Optional[Sequence[int]] = None
 ) -> Protein:
     """Assembles a protein from a prediction.
@@ -286,7 +350,9 @@ def from_prediction(
         features: Dictionary holding model inputs.
         result: Dictionary holding model outputs.
         b_factors: (Optional) B-factors to use for the protein.
+        chain_index: (Optional) Chain indices for multi-chain predictions
+        remark: (Optional) Remark about the prediction
+        parents: (Optional) List of template names
     Returns:
         A protein instance.
     """
@@ -299,4 +365,8 @@ def from_prediction(
         atom_mask=result["final_atom_mask"],
         residue_index=features["residue_index"] + 1,
         b_factors=b_factors,
+        chain_index=chain_index,
+        remark=remark,
+        parents=parents,
+        parents_chain_index=parents_chain_index,
     )
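With these `protein.py` changes, a multi-chain prediction renders with `REMARK`/`PARENT` header records, a `TER` record per chain, and chain tags drawn from `string.ascii_uppercase`. A hedged usage sketch; `batch`, `out`, `b_factors`, `len_a`, and `len_b` are placeholders standing in for a prior prediction:

```python
import numpy as np
from openfold.np import protein

# `batch`, `out`, `b_factors`, `len_a`, `len_b` are placeholders here.
prot = protein.from_prediction(
    features=batch,
    result=out,
    b_factors=b_factors,
    chain_index=np.array([0] * len_a + [1] * len_b),
    remark="no_recycling=3, config_preset=model_1_ptm",
    parents=["1xyz_A"],
    parents_chain_index=[0],
)
pdb_str = protein.to_pdb(prot)
# pdb_str now begins with:
#   REMARK no_recycling=3, config_preset=model_1_ptm
#   PARENT 1xyz_A
# and chain A closes with its own TER record before chain B's first ATOM line.
```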
...
@@ -192,6 +192,11 @@ def clean_protein(prot: protein.Protein, checks: bool = True):
     pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
     if checks:
         _check_cleaned_atoms(pdb_string, prot_pdb_string)

+    headers = protein.get_pdb_headers(prot)
+    if(len(headers) > 0):
+        pdb_string = '\n'.join(['\n'.join(headers), pdb_string])
+
     return pdb_string
...
@@ -87,4 +87,9 @@ class AmberRelaxation(object):
         violations = out["structural_violations"][
             "total_per_residue_violations_mask"
         ]

+        headers = protein.get_pdb_headers(prot)
+        if(len(headers) > 0):
+            min_pdb = '\n'.join(['\n'.join(headers), min_pdb])
+
         return min_pdb, debug_data, violations
...
@@ -21,6 +21,9 @@ import numpy as np
 import os
 import pickle
+from pytorch_lightning.utilities.deepspeed import (
+    convert_zero_checkpoint_to_fp32_state_dict
+)
 import random
 import sys
 import time
@@ -42,12 +45,160 @@ from openfold.utils.tensor_utils import (
 from scripts.utils import add_data_args


+def precompute_alignments(tags, seqs, alignment_dir, args):
+    for tag, seq in zip(tags, seqs):
+        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
+        with open(tmp_fasta_path, "w") as fp:
+            fp.write(f">{tag}\n{seq}")
+
+        local_alignment_dir = os.path.join(alignment_dir, tag)
+        if(args.use_precomputed_alignments is None):
+            logging.info(f"Generating alignments for {tag}...")
+
+            if not os.path.exists(local_alignment_dir):
+                os.makedirs(local_alignment_dir)
+
+            # Small BFD iff the "reduced_dbs" preset is selected
+            use_small_bfd = (args.preset == "reduced_dbs")
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                use_small_bfd=use_small_bfd,
+                no_cpus=args.cpus,
+            )
+            alignment_runner.run(
+                tmp_fasta_path, local_alignment_dir
+            )
+
+        # Remove temporary FASTA file
+        os.remove(tmp_fasta_path)
+
+
+def run_model(model, batch, tag, args):
+    logging.info("Executing model...")
+    with torch.no_grad():
+        batch = {
+            k: torch.as_tensor(v, device=args.model_device)
+            for k, v in batch.items()
+        }
+
+        # Disable templates if there aren't any in the batch
+        model.config.template.enabled = any([
+            "template_" in k for k in batch
+        ])
+
+        logging.info(f"Running inference for {tag}...")
+        t = time.perf_counter()
+        out = model(batch)
+        logging.info(f"Inference time: {time.perf_counter() - t}")
+
+    return out
+
+
+def prep_output(out, batch, feature_dict, feature_processor, args):
+    plddt = out["plddt"]
+    mean_plddt = np.mean(plddt)
+
+    plddt_b_factors = np.repeat(
+        plddt[..., None], residue_constants.atom_type_num, axis=-1
+    )
+
+    # Prep protein metadata
+    template_domain_names = []
+    template_chain_index = None
+    if(feature_processor.config.common.use_templates):
+        template_domain_names = [
+            t.decode("utf-8") for t in feature_dict["template_domain_names"]
+        ]
+
+        # This works because templates are not shuffled during inference
+        template_domain_names = template_domain_names[
+            :feature_processor.config.predict.max_templates
+        ]
+
+        if("template_chain_index" in feature_dict):
+            template_chain_index = feature_dict["template_chain_index"]
+            template_chain_index = template_chain_index[
+                :feature_processor.config.predict.max_templates
+            ]
+
+    no_recycling = feature_processor.config.common.max_recycling_iters
+    remark = ', '.join([
+        f"no_recycling={no_recycling}",
+        f"max_templates={feature_processor.config.predict.max_templates}",
+        f"config_preset={args.model_name}",
+    ])
+
+    # For multi-chain FASTAs
+    ri = feature_dict["residue_index"]
+    chain_index = (ri - np.arange(ri.shape[0])) / args.multimer_ri_gap
+    chain_index = chain_index.astype(np.int64)
+    cur_chain = 0
+    prev_chain_max = 0
+    for i, c in enumerate(chain_index):
+        if(c != cur_chain):
+            cur_chain = c
+            prev_chain_max = i + cur_chain * args.multimer_ri_gap
+
+        batch["residue_index"][i] -= prev_chain_max
+
+    unrelaxed_protein = protein.from_prediction(
+        features=batch,
+        result=out,
+        b_factors=plddt_b_factors,
+        chain_index=chain_index,
+        remark=remark,
+        parents=template_domain_names,
+        parents_chain_index=template_chain_index,
+    )
+
+    return unrelaxed_protein
+
 def main(args):
+    # Create the output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Prep the model
     config = model_config(args.model_name)
     model = AlphaFold(config)
     model = model.eval()
-    import_jax_weights_(model, args.param_path, version=args.model_name)
-    #script_preset_(model)
+
+    if(args.jax_param_path):
+        import_jax_weights_(
+            model, args.jax_param_path, version=args.model_name
+        )
+    elif(args.openfold_checkpoint_path):
+        if(os.path.isdir(args.openfold_checkpoint_path)):
+            checkpoint_basename = os.path.splitext(
+                os.path.basename(
+                    os.path.normpath(args.openfold_checkpoint_path)
+                )
+            )[0]
+            ckpt_path = os.path.join(
+                args.output_dir,
+                checkpoint_basename + ".pt",
+            )
+
+            if(not os.path.isfile(ckpt_path)):
+                convert_zero_checkpoint_to_fp32_state_dict(
+                    args.openfold_checkpoint_path,
+                    ckpt_path,
+                )
+        else:
+            ckpt_path = args.openfold_checkpoint_path
+
+        d = torch.load(ckpt_path)
+        model.load_state_dict(d["ema"]["params"])
+    else:
+        raise ValueError(
+            "At least one of jax_param_path or openfold_checkpoint_path must "
+            "be specified."
+        )
+
     model = model.to(args.model_device)

     template_featurizer = templates.TemplateHitFeaturizer(
@@ -77,6 +228,9 @@ def main(args):
     else:
         alignment_dir = args.use_precomputed_alignments

+    prediction_dir = os.path.join(args.output_dir, "predictions")
+    os.makedirs(prediction_dir, exist_ok=True)
+
     for fasta_file in os.listdir(args.fasta_dir):
         # Gather input sequences
         with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
@@ -88,81 +242,59 @@ def main(args):
         ][1:]
         tags, seqs = lines[::2], lines[1::2]

-        assert len(seqs) == 1, "Input FASTAs may only contain one sequence"
-        tag, seq = tags[0], seqs[0]
-
-        fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
-        with open(fasta_path, "w") as fp:
-            fp.write(f">{tag}\n{seq}")
-
-        logging.info("Generating features...")
-        local_alignment_dir = os.path.join(alignment_dir, tag)
-        if(args.use_precomputed_alignments is None):
-            if not os.path.exists(local_alignment_dir):
-                os.makedirs(local_alignment_dir)
-
-            alignment_runner = data_pipeline.AlignmentRunner(
-                jackhmmer_binary_path=args.jackhmmer_binary_path,
-                hhblits_binary_path=args.hhblits_binary_path,
-                hhsearch_binary_path=args.hhsearch_binary_path,
-                uniref90_database_path=args.uniref90_database_path,
-                mgnify_database_path=args.mgnify_database_path,
-                bfd_database_path=args.bfd_database_path,
-                uniclust30_database_path=args.uniclust30_database_path,
-                pdb70_database_path=args.pdb70_database_path,
-                use_small_bfd=use_small_bfd,
-                no_cpus=args.cpus,
-            )
-            alignment_runner.run(
-                fasta_path, local_alignment_dir
-            )
-
-        feature_dict = data_processor.process_fasta(
-            fasta_path=fasta_path, alignment_dir=local_alignment_dir
-        )
+        tags = [t.split()[0] for t in tags]
+        assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
+        tag = '-'.join(tags)
+
+        precompute_alignments(tags, seqs, alignment_dir, args)
+
+        tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
+        if(len(seqs) == 1):
+            seq = seqs[0]
+            with open(tmp_fasta_path, "w") as fp:
+                fp.write(f">{tag}\n{seq}")
+
+            local_alignment_dir = os.path.join(alignment_dir, tag)
+            feature_dict = data_processor.process_fasta(
+                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
+            )
+        else:
+            with open(tmp_fasta_path, "w") as fp:
+                fp.write(
+                    '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
+                )
+            feature_dict = data_processor.process_multiseq_fasta(
+                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
+            )

         # Remove temporary FASTA file
-        os.remove(fasta_path)
+        os.remove(tmp_fasta_path)

         processed_feature_dict = feature_processor.process_features(
             feature_dict, mode='predict',
         )

-        logging.info("Executing model...")
         batch = processed_feature_dict
-        with torch.no_grad():
-            batch = {
-                k: torch.as_tensor(v, device=args.model_device)
-                for k, v in batch.items()
-            }
-
-            t = time.perf_counter()
-            out = model(batch)
-            logging.info(f"Inference time: {time.perf_counter() - t}")
+        out = run_model(model, batch, tag, args)

         # Toss out the recycling dimensions --- we don't need them anymore
         batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
         out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

-        plddt = out["plddt"]
-        mean_plddt = np.mean(plddt)
-
-        plddt_b_factors = np.repeat(
-            plddt[..., None], residue_constants.atom_type_num, axis=-1
-        )
-
-        unrelaxed_protein = protein.from_prediction(
-            features=batch,
-            result=out,
-            b_factors=plddt_b_factors
-        )
+        unrelaxed_protein = prep_output(
+            out, batch, feature_dict, feature_processor, args
+        )
+
+        output_name = f'{tag}_{args.model_name}'
+        if(args.output_postfix is not None):
+            output_name = f'{output_name}_{args.output_postfix}'

         # Save the unrelaxed PDB.
         unrelaxed_output_path = os.path.join(
-            args.output_dir, f'{tag}_{args.model_name}_unrelaxed.pdb'
+            prediction_dir, f'{output_name}_unrelaxed.pdb'
         )
-        with open(unrelaxed_output_path, 'w') as f:
-            f.write(protein.to_pdb(unrelaxed_protein))
+        with open(unrelaxed_output_path, 'w') as fp:
+            fp.write(protein.to_pdb(unrelaxed_protein))

         if(not args.skip_relaxation):
             amber_relaxer = relax.AmberRelaxation(
@@ -182,14 +314,14 @@ def main(args):
             # Save the relaxed PDB.
             relaxed_output_path = os.path.join(
-                args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb'
+                prediction_dir, f'{output_name}_relaxed.pdb'
             )
-            with open(relaxed_output_path, 'w') as f:
-                f.write(relaxed_pdb_str)
+            with open(relaxed_output_path, 'w') as fp:
+                fp.write(relaxed_pdb_str)

         if(args.save_outputs):
             output_dict_path = os.path.join(
-                args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
+                args.output_dir, f'{output_name}_output_dict.pkl'
             )
             with open(output_dict_path, "wb") as fp:
                 pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
@@ -224,10 +356,15 @@ if __name__ == "__main__":
             model_{1-5}_ptm, as defined on the AlphaFold GitHub."""
     )
     parser.add_argument(
-        "--param_path", type=str, default=None,
-        help="""Path to model parameters. If None, parameters are selected
-            automatically according to the model name from
-            openfold/resources/params"""
+        "--jax_param_path", type=str, default=None,
+        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
+            is also None, parameters are selected automatically according to
+            the model name from openfold/resources/params"""
+    )
+    parser.add_argument(
+        "--openfold_checkpoint_path", type=str, default=None,
+        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
+            checkpoint directory or a .pt file"""
     )
     parser.add_argument(
         "--save_outputs", action="store_true", default=False,
@@ -241,17 +378,25 @@ if __name__ == "__main__":
         "--preset", type=str, default='full_dbs',
         choices=('reduced_dbs', 'full_dbs')
     )
+    parser.add_argument(
+        "--output_postfix", type=str, default=None,
+        help="""Postfix for output prediction filenames"""
+    )
     parser.add_argument(
         "--data_random_seed", type=str, default=None
     )
     parser.add_argument(
         "--skip_relaxation", action="store_true", default=False,
     )
+    parser.add_argument(
+        "--multimer_ri_gap", type=int, default=200,
+        help="""Residue index offset between multiple sequences, if provided"""
+    )
     add_data_args(parser)
     args = parser.parse_args()

-    if(args.param_path is None):
-        args.param_path = os.path.join(
+    if(args.jax_param_path is None and args.openfold_checkpoint_path is None):
+        args.jax_param_path = os.path.join(
             "openfold", "resources", "params",
             "params_" + args.model_name + ".npz"
         )
...
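On the inference side, `prep_output` inverts the residue-index trick to recover per-residue chain indices: a residue at global position `i` in chain `c` carries an extra offset of `c * multimer_ri_gap`, so subtracting `np.arange` isolates the accumulated gap and dividing by the gap yields the chain number. A worked example with invented lengths:

```python
# Hedged sketch of the chain recovery in prep_output: two chains of
# lengths 3 and 2 with ri_gap=200 produce residue_index [0,1,2,203,204].
import numpy as np

ri = np.array([0, 1, 2, 203, 204])
chain_index = ((ri - np.arange(ri.shape[0])) / 200).astype(np.int64)
print(chain_index.tolist())  # [0, 0, 0, 1, 1]
```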