Unverified Commit eecadd36 authored by shenggan, committed by GitHub

Merge pull request #14 from hpcaitech/inference_pipeline

add inference pipeline from openfold/alphafold
parents b3b3b445 42427a1c
@@ -41,10 +41,10 @@ python setup.py install
## Usage
-You can use `Evoformer` as `nn.Module` in your project after `from fastfold.model import Evoformer`:
+You can use `Evoformer` as an `nn.Module` in your project after `from fastfold.model.fastnn import Evoformer`:
```python
-from fastfold.model import Evoformer
+from fastfold.model.fastnn import Evoformer
evoformer_layer = Evoformer()
```
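For a quick smoke test, the layer can be driven with random tensors. The shapes and the exact forward signature below are assumptions for illustration, not FastFold's documented API; check `fastfold/model/fastnn` for the real interface:

```python
import torch
from fastfold.model.fastnn import Evoformer

evoformer_layer = Evoformer().cuda()

# Assumed shapes: MSA representation [batch, n_seq, n_res, c_m] and
# pair representation [batch, n_res, n_res, c_z]; all-ones masks.
m = torch.randn(1, 64, 128, 256, device="cuda")
z = torch.randn(1, 128, 128, 128, device="cuda")
m_mask = torch.ones(1, 64, 128, device="cuda")
z_mask = torch.ones(1, 128, 128, device="cuda")

m, z = evoformer_layer(m, z, m_mask, z_mask)  # argument list is an assumption
```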
@@ -58,15 +58,15 @@ init_dap(args.dap_size)
### Inference
-You can use FastFold alongwith OpenFold with `inject_openfold`. This will replace the evoformer in OpenFold with the high performance evoformer from FastFold.
+You can use FastFold along with OpenFold via `inject_fastnn`. This replaces the Evoformer in OpenFold with the high-performance Evoformer from FastFold.
```python
-from fastfold.utils import inject_openfold
+from fastfold.utils import inject_fastnn
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
-model = inject_openfold(model)
+model = inject_fastnn(model)
```
For Dynamic Axial Parallelism, you can refer to `./inference.py`. Here is an example of parallel inference on 2 GPUs:
......
@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from fastfold.distributed import init_dap
-from fastfold.model import Evoformer
+from fastfold.model.fastnn import Evoformer
def main():
......
name: fastfold
channels:
  - conda-forge
  - bioconda
  - pytorch
dependencies:
  - pip:
      - biopython==1.79
      - dm-tree==0.1.6
      - ml-collections==0.1.0
      - numpy==1.21.2
      - PyYAML==5.4.1
      - requests==2.26.0
      - scipy==1.7.1
      - tqdm==4.62.2
      - typing-extensions==3.10.0.2
      - einops
      - colossalai
  - pytorch::pytorch=1.11.0
  - conda-forge::python=3.8
  - conda-forge::setuptools=59.5.0
  - conda-forge::pip
  - conda-forge::openmm=7.5.1
  - conda-forge::pdbfixer
  - bioconda::hmmer==3.3.2
  - bioconda::hhsuite==3.3.0
  - bioconda::kalign2==2.04
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.common import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation."""

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
        pdb_str: The contents of the pdb file
        chain_id: If None, then the pdb file must contain a single chain (which
            will be parsed). If chain_id is specified (e.g. A), then only that
            chain is parsed.

    Returns:
        A new `Protein` parsed from the pdb contents.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(f"Only single model PDBs are supported. Found {len(models)} models.")
    model = models[0]

    if chain_id is not None:
        chain = model[chain_id]
    else:
        chains = list(model.get_chains())
        if len(chains) != 1:
            raise ValueError("Only single chain PDBs are supported when chain_id not specified. "
                             f"Found {len(chains)} chains.")
        else:
            chain = chains[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    b_factors = []

    for res in chain:
        if res.id[2] != " ":
            raise ValueError(f"PDB contains an insertion code at chain {chain.id} and residue "
                             f"index {res.id[1]}. These are not supported.")
        res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
        restype_idx = residue_constants.restype_order.get(res_shortname,
                                                          residue_constants.restype_num)
        pos = np.zeros((residue_constants.atom_type_num, 3))
        mask = np.zeros((residue_constants.atom_type_num,))
        res_b_factors = np.zeros((residue_constants.atom_type_num,))
        for atom in res:
            if atom.name not in residue_constants.atom_types:
                continue
            pos[residue_constants.atom_order[atom.name]] = atom.coord
            mask[residue_constants.atom_order[atom.name]] = 1.0
            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
        if np.sum(mask) < 0.5:
            # If no known atom positions are reported for the residue then skip it.
            continue
        aatype.append(restype_idx)
        atom_positions.append(pos)
        atom_mask.append(mask)
        residue_index.append(res.id[1])
        b_factors.append(res_b_factors)

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        b_factors=np.array(b_factors),
    )
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if ("[PRIMARY]" == g[0]):
            seq = g[1][0].strip()
            # Strings are immutable, so build a sanitized copy rather than
            # assigning to individual characters.
            seq = ''.join(c if c in residue_constants.restypes else 'X' for c in seq)
            aatype = np.array([
                residue_constants.restype_order.get(res_symbol, residue_constants.restype_num)
                for res_symbol in seq
            ])
        elif ("[TERTIARY]" == g[0]):
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (np.transpose(
                    tertiary_np[:, i::3]))
            atom_positions *= PICO_TO_ANGSTROM
        elif ("[MASK]" == g[0]):
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros((
                len(mask),
                residue_constants.atom_type_num,
            )).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
        prot: The protein to convert to PDB.

    Returns:
        PDB string.
    """
    restypes = residue_constants.restypes + ["X"]
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    pdb_lines.append("MODEL     1")
    atom_index = 1
    chain_id = "A"
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i],
                                                  b_factors[i]):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ""
            # PDB is a columnar format, every space matters here!
            atom_line = (f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                         f"{res_name_3:>3} {chain_id:>1}"
                         f"{residue_index[i]:>4}{insertion_code:>1}   "
                         f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                         f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                         f"{element:>2}{charge:>2}")
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the chain.
    chain_end = "TER"
    chain_termination_line = (f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} "
                              f"{chain_id:>1}{residue_index[-1]:>4}")
    pdb_lines.append(chain_termination_line)
    pdb_lines.append("ENDMDL")

    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
        prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
        An ideal atom mask.
    """
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
        features: Dictionary holding model inputs.
        result: Dictionary holding model outputs.
        b_factors: (Optional) B-factors to use for the protein.

    Returns:
        A protein instance.
    """
    if b_factors is None:
        b_factors = np.zeros_like(result["final_atom_mask"])

    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=features["residue_index"] + 1,
        b_factors=b_factors,
    )
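Taken together, `from_pdb_string` and `to_pdb` give a lossy round trip through the `Protein` representation (non-standard residues become UNK, non-standard atoms are dropped). A minimal sketch, assuming a hypothetical local single-chain file `example.pdb`:

```python
with open("example.pdb") as f:  # hypothetical path
    pdb_str = f.read()

prot = from_pdb_string(pdb_str, chain_id="A")
# 37 atom slots per residue, following residue_constants.atom_types.
print(prot.aatype.shape, prot.atom_positions.shape)  # (num_res,), (num_res, 37, 3)
print(prot.atom_mask.sum())                          # number of resolved atoms

print(to_pdb(prot).splitlines()[0])                  # "MODEL     1"
```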
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import ml_collections as mlc
def set_inf(c, inf):
    for k, v in c.items():
        if isinstance(v, mlc.ConfigDict):
            set_inf(v, inf)
        elif k == "inf":
            c[k] = inf
def model_config(name, train=False, low_prec=False):
    c = copy.deepcopy(config)
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    elif name == "finetuning":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.common.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
    elif name == "model_1":
        # AF2 Suppl. Table 5, Model 1.1.1
        c.data.common.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
    elif name == "model_2":
        # AF2 Suppl. Table 5, Model 1.1.2
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
    elif name == "model_3":
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.common.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_4":
        # AF2 Suppl. Table 5, Model 1.2.2
        c.data.common.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_5":
        # AF2 Suppl. Table 5, Model 1.2.3
        c.model.template.enabled = False
    elif name == "model_1_ptm":
        c.data.common.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_2_ptm":
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_3_ptm":
        c.data.common.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_4_ptm":
        c.data.common.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_5_ptm":
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    else:
        raise ValueError("Invalid model name")

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    return c
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                "feat": {
                    "aatype": [NUM_RES],
                    "all_atom_mask": [NUM_RES, None],
                    "all_atom_positions": [NUM_RES, None, None],
                    "alt_chi_angles": [NUM_RES, None],
                    "atom14_alt_gt_exists": [NUM_RES, None],
                    "atom14_alt_gt_positions": [NUM_RES, None, None],
                    "atom14_atom_exists": [NUM_RES, None],
                    "atom14_atom_is_ambiguous": [NUM_RES, None],
                    "atom14_gt_exists": [NUM_RES, None],
                    "atom14_gt_positions": [NUM_RES, None, None],
                    "atom37_atom_exists": [NUM_RES, None],
                    "backbone_rigid_mask": [NUM_RES],
                    "backbone_rigid_tensor": [NUM_RES, None, None],
                    "bert_mask": [NUM_MSA_SEQ, NUM_RES],
                    "chi_angles_sin_cos": [NUM_RES, None, None],
                    "chi_mask": [NUM_RES, None],
                    "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_row_mask": [NUM_EXTRA_SEQ],
                    "is_distillation": [],
                    "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
                    "msa_mask": [NUM_MSA_SEQ, NUM_RES],
                    "msa_row_mask": [NUM_MSA_SEQ],
                    "no_recycling_iters": [],
                    "pseudo_beta": [NUM_RES, None],
                    "pseudo_beta_mask": [NUM_RES],
                    "residue_index": [NUM_RES],
                    "residx_atom14_to_atom37": [NUM_RES, None],
                    "residx_atom37_to_atom14": [NUM_RES, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
                    "rigidgroups_group_exists": [NUM_RES, None],
                    "rigidgroups_group_is_ambiguous": [NUM_RES, None],
                    "rigidgroups_gt_exists": [NUM_RES, None],
                    "rigidgroups_gt_frames": [NUM_RES, None, None, None],
                    "seq_length": [],
                    "seq_mask": [NUM_RES],
                    "target_feat": [NUM_RES, None],
                    "template_aatype": [NUM_TEMPLATES, NUM_RES],
                    "template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
                    "template_all_atom_positions": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_alt_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_backbone_rigid_tensor": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_mask": [NUM_TEMPLATES],
                    "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
                    "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_sum_probs": [NUM_TEMPLATES, None],
                    "template_torsion_angles_mask": [
                        NUM_TEMPLATES, NUM_RES, None,
                    ],
                    "template_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "true_msa": [NUM_MSA_SEQ, NUM_RES],
                    "use_clamped_fape": [],
                },
                "masked_msa": {
                    "profile_prob": 0.1,
                    "same_prob": 0.1,
                    "uniform_prob": 0.1,
                },
                "max_extra_msa": 1024,
                "max_recycling_iters": 3,
                "msa_cluster_features": True,
                "reduce_msa_clusters_by_max_templates": False,
                "resample_msa_in_recycling": True,
                "template_features": [
                    "template_all_atom_positions",
                    "template_sum_probs",
                    "template_aatype",
                    "template_all_atom_mask",
                ],
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "msa",
                    "num_alignments",
                    "seq_length",
                    "between_segment_residues",
                    "deletion_matrix",
                    "no_recycling_iters",
                ],
                "use_templates": templates_enabled,
                "use_template_torsion_angles": embed_template_torsion_angles,
            },
            "supervised": {
                "clamp_prob": 0.9,
                "supervised_features": [
                    "all_atom_mask",
                    "all_atom_positions",
                    "resolution",
                    "use_clamped_fape",
                    "is_distillation",
                ],
            },
            "predict": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
                "crop_size": None,
                "supervised": False,
                "uniform_recycling": False,
            },
            "eval": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
                "crop_size": None,
                "supervised": True,
                "uniform_recycling": False,
            },
            "train": {
                "fixed_size": True,
                "subsample_templates": True,
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_template_hits": 4,
                "max_templates": 4,
                "shuffle_top_k_prefiltered": 20,
                "crop": True,
                "crop_size": 256,
                "supervised": True,
                "clamp_prob": 0.9,
                "max_distillation_msa_clusters": 1000,
                "uniform_recycling": True,
            },
            "data_module": {
                "use_small_bfd": False,
                "data_loaders": {
                    "batch_size": 1,
                    "num_workers": 16,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
        },
        "model": {
            "_mask_trans": False,
            "input_embedder": {
                "tf_dim": 22,
                "msa_dim": 49,
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                "inf": 1e8,
            },
            "template": {
                "distogram": {
                    "min_bin": 3.25,
                    "max_bin": 50.75,
                    "no_bins": 39,
                },
                "template_angle_embedder": {
                    # DISCREPANCY: c_in is supposed to be 51.
                    "c_in": 57,
                    "c_out": c_m,
                },
                "template_pair_embedder": {
                    "c_in": 88,
                    "c_out": c_t,
                },
                "template_pair_stack": {
                    "c_t": c_t,
                    # DISCREPANCY: c_hidden_tri_att here is given in the supplement
                    # as 64. In the code, it's 16.
                    "c_hidden_tri_att": 16,
                    "c_hidden_tri_mul": 64,
                    "no_blocks": 2,
                    "no_heads": 4,
                    "pair_transition_n": 2,
                    "dropout_rate": 0.25,
                    "blocks_per_ckpt": blocks_per_ckpt,
                    "inf": 1e9,
                },
                "template_pointwise_attention": {
                    "c_t": c_t,
                    "c_z": c_z,
                    # DISCREPANCY: c_hidden here is given in the supplement as 64.
                    # It's actually 16.
                    "c_hidden": 16,
                    "no_heads": 4,
                    "inf": 1e5,  # 1e9,
                },
                "inf": 1e5,  # 1e9,
                "eps": eps,  # 1e-6,
                "enabled": templates_enabled,
                "embed_angles": embed_template_torsion_angles,
            },
            "extra_msa": {
                "extra_msa_embedder": {
                    "c_in": 25,
                    "c_out": c_e,
                },
                "extra_msa_stack": {
                    "c_m": c_e,
                    "c_z": c_z,
                    "c_hidden_msa_att": 8,
                    "c_hidden_opm": 32,
                    "c_hidden_mul": 128,
                    "c_hidden_pair_att": 32,
                    "no_heads_msa": 8,
                    "no_heads_pair": 4,
                    "no_blocks": 4,
                    "transition_n": 4,
                    "msa_dropout": 0.15,
                    "pair_dropout": 0.25,
                    "clear_cache_between_blocks": True,
                    "inf": 1e9,
                    "eps": eps,  # 1e-10,
                    "ckpt": blocks_per_ckpt is not None,
                },
                "enabled": True,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_msa_att": 32,
                "c_hidden_opm": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_msa": 8,
                "no_heads_pair": 4,
                "no_blocks": 48,
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "tm": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "enabled": tm_enabled,
                },
                "masked_msa": {
                    "c_m": c_m,
                    "c_out": 23,
                },
                "experimentally_resolved": {
                    "c_s": c_s,
                    "c_out": 37,
                },
            },
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                "no_bins": 64,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "experimentally_resolved": {
                "eps": eps,  # 1e-8,
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "weight": 0.0,
            },
            "fape": {
                "backbone": {
                    "clamp_distance": 10.0,
                    "loss_unit_distance": 10.0,
                    "weight": 0.5,
                },
                "sidechain": {
                    "clamp_distance": 10.0,
                    "length_scale": 10.0,
                    "weight": 0.5,
                },
                "eps": 1e-4,
                "weight": 1.0,
            },
            "lddt": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "masked_msa": {
                "eps": eps,  # 1e-8,
                "weight": 2.0,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "violation": {
                "violation_tolerance_factor": 12.0,
                "clash_overlap_tolerance": 1.5,
                "eps": eps,  # 1e-6,
                "weight": 0.0,
            },
            "tm": {
                "max_bin": 31,
                "no_bins": 64,
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "eps": eps,  # 1e-8,
                "weight": 0.0,
                "enabled": tm_enabled,
            },
            "eps": eps,
        },
        "ema": {"decay": 0.999},
    }
)
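The presets above only toggle a handful of fields; everything else flows through the shared `FieldReference`s, so values like `chunk_size` can still be adjusted per run. A short sketch of how the factory is meant to be used:

```python
# pTM preset: template embedding plus the TM-score head.
c = model_config("model_1_ptm")
assert c.model.template.enabled
assert c.model.heads.tm.enabled
c.globals.chunk_size = 8  # larger chunks: faster, more memory

# train=True enables per-block activation checkpointing and disables chunking;
# low_prec=True relaxes eps and rewrites every "inf" entry via set_inf.
c_train = model_config("initial_training", train=True, low_prec=True)
assert c_train.globals.blocks_per_ckpt == 1
assert c_train.globals.chunk_size is None
assert c_train.model.evoformer_stack.inf == 1e4
```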
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General-purpose errors used throughout the data pipeline"""
class Error(Exception):
    """Base class for exceptions."""


class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID."""
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Mapping, Tuple, List, Optional, Dict, Sequence
import ml_collections
import numpy as np
import torch
from fastfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: A list of strings of feature names to be returned in the dataset.

    Returns:
        A dictionary of features mapping feature names to features. Only the given
        features are returned, all other ones are filtered out.
    """
    tensor_dict = {
        k: torch.tensor(v) for k, v in np_example.items() if k in features
    }

    return tensor_dict


def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features

    if cfg.common.use_templates:
        feature_names += cfg.common.template_features

    if cfg[mode].supervised:
        feature_names += cfg.supervised.supervised_features

    return cfg, feature_names


def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
):
    np_example = dict(np_example)
    num_res = int(np_example["seq_length"][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if "deletion_matrix_int" in np_example:
        np_example["deletion_matrix"] = np_example.pop(
            "deletion_matrix_int"
        ).astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
        )

    return {k: v for k, v in features.items()}


class FeaturePipeline:
    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        self.config = config

    def process_features(
        self,
        raw_features: FeatureDict,
        mode: str = "train",
    ) -> FeatureDict:
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            mode=mode,
        )
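End to end, the pipeline is a thin wrapper: select feature names from the config, tensorize, then hand off to `input_pipeline`. A usage sketch, assuming `raw_features` came from the upstream data pipeline (a `features.pkl`-style dict; the path and the `fastfold.config` import are assumptions):

```python
import pickle

from fastfold.config import model_config  # assumed module path for the config above

with open("features.pkl", "rb") as f:  # hypothetical path
    raw_features = pickle.load(f)

config = model_config("model_1")
pipeline = FeaturePipeline(config.data)
batch = pipeline.process_features(raw_features, mode="predict")
print({k: tuple(v.shape) for k, v in batch.items()})
```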
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""
import glob
import logging
import os
import subprocess
from typing import Sequence
from fastfold.data.tools import utils
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be the
                common prefix for the database files (i.e. up to but not including
                _hhm.ffindex etc.)
            n_cpu: The number of CPUs to use
            maxseq: The maximum number of rows in an input alignment. Note that this
                parameter is only supported in HHBlits version 3.1 and higher.

        Raises:
            RuntimeError: If HHsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHsearch database {database_path}"
                )

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_path,
                "-o",
                hhr_path,
                "-maxseq",
                str(self.maxseq),
                "-cpu",
                str(self.n_cpu),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
                )

            # Read the hits before the temporary directory is cleaned up.
            with open(hhr_path) as f:
                hhr = f.read()

        return hhr
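Typical use mirrors AlphaFold's template search: feed the query MSA as a3m text and parse the returned `.hhr` hits downstream. The paths below are hypothetical:

```python
hhsearch = HHSearch(
    binary_path="/usr/bin/hhsearch",   # hypothetical install location
    databases=["/data/pdb70/pdb70"],   # prefix shared by pdb70_*.ff{data,index}
)

with open("query.a3m") as f:           # hypothetical query alignment
    hhr = hhsearch.query(f.read())     # raw .hhr text

print(hhr.splitlines()[0])
```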