Commit 13f8f163 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents a509a4c5 b5fa2ba3
Pipeline #235 failed with stages
in 0 seconds
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Mapping, Tuple, List, Optional, Dict, Sequence
import ml_collections
import numpy as np
import torch
from openfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
np_example: A dict of NumPy feature arrays.
features: A list of strings of feature names to be returned in the dataset.
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
def make_data_config(
config: ml_collections.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if cfg[mode].supervised:
feature_names += cfg.supervised.supervised_features
return cfg, feature_names
def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
).astype(np.float32)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
if mode == "train":
p = torch.rand(1).item()
use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=use_clamped_fape_value,
dtype=torch.float32,
)
else:
features["use_clamped_fape"] = torch.full(
size=[cfg.common.max_recycling_iters + 1],
fill_value=0.0,
dtype=torch.float32,
)
return {k: v for k, v in features.items()}
class FeaturePipeline:
def __init__(
self,
config: ml_collections.ConfigDict,
):
self.config = config
def process_features(
self,
raw_features: FeatureDict,
mode: str = "train",
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from openfold.data import data_transforms
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes,
data_transforms.squeeze_features,
data_transforms.randomly_replace_msa_with_unknown(0.0),
data_transforms.make_seq_mask,
data_transforms.make_msa_mask,
data_transforms.make_hhblits_profile,
]
if common_cfg.use_templates:
transforms.extend(
[
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta("template_"),
]
)
if common_cfg.use_template_torsion_angles:
transforms.extend(
[
data_transforms.atom37_to_torsion_angles("template_"),
]
)
transforms.extend(
[
data_transforms.make_atom14_masks,
]
)
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
if "max_distillation_msa_clusters" in mode_cfg:
transforms.append(
data_transforms.sample_msa_distillation(
mode_cfg.max_distillation_msa_clusters
)
)
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = mode_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms.sample_msa(
max_msa_clusters,
keep_extra=True,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
)
)
if common_cfg.msa_cluster_features:
transforms.append(data_transforms.nearest_neighbor_clusters())
transforms.append(data_transforms.summarize_clusters())
# Crop after creating the cluster profiles.
if max_extra_msa:
transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
else:
transforms.append(data_transforms.delete_extra_msa)
transforms.append(data_transforms.make_msa_feat())
crop_feats = dict(common_cfg.feat)
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(
data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
seed=ensemble_seed + 1,
)
)
transforms.append(
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
mode_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import io
import json
import logging
import os
from typing import Any, Mapping, Optional, Sequence, Tuple
from Bio import PDB
from Bio.Data import SCOPData
import numpy as np
from openfold.data.errors import MultipleChainsError
import openfold.np.residue_constants as residue_constants
# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]
@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str
@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.
Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
1: ResidueAtPosition,
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any
@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.
Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(
prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
"mmCIF error: Not all loops are the same length: %s" % cols
)
return [dict(zip(cols, xs)) for xs in zip(*data)]
def mmcif_loop_to_dict(
prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.
Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}
def parse(
*, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.
Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure("", handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access
# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]
header = _get_header(parsed_info)
# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ""): "No protein chains found in this file."}
)
seq_start_num = {
chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()
}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != "1":
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = " "
if atom.hetatm_atom == "HETATM":
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ("HOH", "WAT"):
hetflag = "W"
else:
hetflag = "H_" + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = " "
position = ResiduePosition(
chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code,
)
seq_idx = (
int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
)
current = seq_to_structure_mappings.get(
atom.author_chain_id, {}
)
current[seq_idx] = ResidueAtPosition(
position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag,
)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(
position=None,
name=monomer.id,
is_missing=True,
hetflag=" ",
)
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X")
seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info,
)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, "")] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
return min(revision_dates)
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list("_exptl.", parsed_info)
header["structure_method"] = ",".join(
[experiment["_exptl.method"].lower() for experiment in experiments]
)
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if "_pdbx_audit_revision_history.revision_date" in parsed_info:
header["release_date"] = get_release_date(parsed_info)
else:
logging.warning(
"Could not determine release_date: %s", parsed_info["_entry.id"]
)
header["resolution"] = 0.00
for res_key in (
"_refine.ls_d_res_high",
"_em_3d_reconstruction.resolution",
"_reflns.d_resolution_high",
):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header["resolution"] = float(raw_resolution)
except ValueError:
logging.info(
"Invalid resolution format: %s", parsed_info[res_key]
)
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [
AtomSite(*site)
for site in zip( # pylint:disable=g-complex-comprehension
parsed_info["_atom_site.label_comp_id"],
parsed_info["_atom_site.auth_asym_id"],
parsed_info["_atom_site.label_asym_id"],
parsed_info["_atom_site.auth_seq_id"],
parsed_info["_atom_site.label_seq_id"],
parsed_info["_atom_site.pdbx_PDB_ins_code"],
parsed_info["_atom_site.group_PDB"],
parsed_info["_atom_site.pdbx_PDB_model_num"],
)
]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
parsed_info: _mmcif_dict produced by the Biopython parser.
Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
Monomer(
id=entity_poly_seq["_entity_poly_seq.mon_id"],
num=int(entity_poly_seq["_entity_poly_seq.num"]),
)
)
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym["_struct_asym.id"]
entity_id = struct_asym["_struct_asym.entity_id"]
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(
[
"peptide" in chem_comps[monomer.id]["_chem_comp.type"]
for monomer in seq_info
]
):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in (".", "?")
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str,
_zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f"Expected exactly one chain in structure with id {chain_id}."
)
chain = relevant_chains[0]
# Extract the coordinates
num_res = len(mmcif_object.chain_to_seqres[chain_id])
all_atom_positions = np.zeros(
[num_res, residue_constants.atom_type_num, 3], dtype=np.float32
)
all_atom_mask = np.zeros(
[num_res, residue_constants.atom_type_num], dtype=np.float32
)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing:
res = chain[
(
res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code,
)
]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
# Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
if _zero_center_positions:
binary_mask = all_atom_mask.astype(bool)
translation_vec = all_atom_positions[binary_mask].mean(axis=0)
all_atom_positions[binary_mask] -= translation_vec
return all_atom_positions, all_atom_mask
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
sum_probs: float
query: str
hit_sequence: str
indices_query: List[int]
indices_hit: List[int]
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.
Arguments:
fasta_string: The string contents of a FASTA file.
Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith(">"):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append("")
continue
elif line.startswith("#"):
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line
return sequences, descriptions
def parse_stockholm(
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* The names of the targets matched, including the jackhmmer subsequence
suffix.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(("#", "//")):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ""
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ""
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != "-" or query_res != "-":
if query_res == "-":
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
return msa, deletion_matrix, list(name_to_sequence.keys())
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
a3m_string: The string contents of a a3m file. The first sequence in the
file should be the query sequence.
Returns:
A tuple of:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
sequences, _ = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != "-":
yield sequence_res.lower()
def convert_stockholm_to_a3m(
stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = (
max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ""
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == "#=GS":
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ""
if feature != "DE":
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break
# Convert sto format to a3m line by line
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (
f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f"Could not parse query line {line}")
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == "-":
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.
This works on .hhr files generated from both HHBlits and HHSearch.
Args:
detailed_lines: A list of lines from a single comparison section between 2
sequences (which each have their own HMM's)
Returns:
A dictionary with the information from that detailed comparison section
Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]
# Parse the summary line.
pattern = (
"Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
" ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
"]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
"Could not parse section: %s. Expected this: \n%s to contain summary."
% (detailed_lines, detailed_lines[2])
)
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ""
hit_sequence = ""
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (
line.startswith("Q ")
and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith("T "):
# Parse the hit sequence.
if (
not line.startswith("T ss_dssp")
and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit(
index=number_of_hit,
name=name_hit,
aligned_cols=int(aligned_cols),
sum_probs=sum_probs,
query=query,
hit_sequence=hit_sequence,
indices_query=indices_query,
indices_hit=indices_hit,
)
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()
# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
for line in lines:
fields = line.split()
e_value = fields[4]
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for getting templates and calculating template features."""
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
import numpy as np
from openfold.data import parsers, mmcif_parsing
from openfold.data.errors import Error
from openfold.data.tools import kalign
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants
class NoChainsError(Error):
"""An error indicating that template mmCIF didn't have any chains."""
class SequenceNotInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain the sequence."""
class NoAtomDataInTemplateError(Error):
"""An error indicating that template mmCIF didn't contain atom positions."""
class TemplateAtomMaskAllZerosError(Error):
"""An error indicating that template mmCIF had all atom positions masked."""
class QueryToTemplateAlignError(Error):
"""An error indicating that the query can't be aligned to the template."""
class CaDistanceError(Error):
"""An error indicating that a CA atom distance exceeds a threshold."""
# Prefilter exceptions.
class PrefilterError(Exception):
"""A base class for template prefilter exceptions."""
class DateError(PrefilterError):
"""An error indicating that the hit date was after the max allowed date."""
class PdbIdError(PrefilterError):
"""An error indicating that the hit PDB ID was identical to the query."""
class AlignRatioError(PrefilterError):
"""An error indicating that the hit align ratio to the query was too small."""
class DuplicateError(PrefilterError):
"""An error indicating that the hit was an exact subsequence of the query."""
class LengthError(PrefilterError):
"""An error indicating that the hit was too short."""
TEMPLATE_FEATURES = {
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_sum_probs": np.float32,
}
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
if not id_match:
raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
pdb_id, chain_id = id_match.group(0).split("_")
return pdb_id.lower(), chain_id
def _is_after_cutoff(
pdb_id: str,
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: Optional[datetime.datetime],
) -> bool:
"""Checks if the template date is after the release date cutoff.
Args:
pdb_id: 4 letter pdb code.
release_dates: Dictionary mapping PDB ids to their structure release dates.
release_date_cutoff: Max release date that is valid for this query.
Returns:
True if the template release date is after the cutoff, False otherwise.
"""
pdb_id_upper = pdb_id.upper()
if release_date_cutoff is None:
raise ValueError("The release_date_cutoff must not be None.")
if pdb_id_upper in release_dates:
return release_dates[pdb_id_upper] > release_date_cutoff
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.info(
"Template structure not in release dates dict: %s", pdb_id
)
return False
def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
"""Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
obsolete_new = {}
obsolete_keys = obsolete_mapping.keys()
def _new_target(k):
v = obsolete_mapping[k]
if v in obsolete_keys:
return _new_target(v)
return v
for k in obsolete_keys:
obsolete_new[k] = _new_target(k)
return obsolete_new
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
"""Parses the data file from PDB that lists which PDB ids are obsolete."""
with open(obsolete_file_path) as f:
result = {}
for line in f:
line = line.strip()
# We skip obsolete entries that don't contain a mapping to a new entry.
if line.startswith("OBSLTE") and len(line) > 30:
# Format: Date From To
# 'OBSLTE 31-JUL-94 116L 216L'
from_id = line[20:24].lower()
to_id = line[29:33].lower()
result[from_id] = to_id
return _replace_obsolete_references(result)
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
dates = {}
for f in os.listdir(mmcif_dir):
if f.endswith(".cif"):
path = os.path.join(mmcif_dir, f)
with open(path, "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
if mmcif.mmcif_object is None:
logging.info(f"Failed to parse {f}. Skipping...")
continue
mmcif = mmcif.mmcif_object
release_date = mmcif.header["release_date"]
dates[file_id] = release_date
with open(out_path, "r") as fp:
fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
with open(path, "r") as fp:
data = json.load(fp)
return {
pdb.upper(): to_date(v)
for pdb, d in data.items()
for k, v in d.items()
if k == "release_date"
}
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
query_sequence: str,
query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1,
) -> bool:
"""Determines if template is valid (without parsing the template mmcif file).
Args:
hit: HhrHit for the template.
hit_pdb_code: The 4 letter pdb code of the template hit. This might be
different from the value in the actual hit since the original pdb might
have become obsolete.
query_sequence: Amino acid sequence of the query.
query_pdb_code: 4 letter pdb code of the query.
release_dates: Dictionary mapping pdb codes to their structure release
dates.
release_date_cutoff: Max release date that is valid for this query.
max_subsequence_ratio: Exclude any exact matches with this much overlap.
min_align_ratio: Minimum overlap between the template and query.
Returns:
True if the hit passed the prefilter. Raises an exception otherwise.
Raises:
DateError: If the hit date was after the max allowed date.
PdbIdError: If the hit PDB ID was identical to the query.
AlignRatioError: If the hit align ratio to the query was too small.
DuplicateError: If the hit was an exact subsequence of the query.
LengthError: If the hit was too short.
"""
aligned_cols = hit.aligned_cols
align_ratio = aligned_cols / len(query_sequence)
template_sequence = hit.hit_sequence.replace("-", "")
length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
date = release_dates[hit_pdb_code.upper()]
raise DateError(
f"Date ({date}) > max template date "
f"({release_date_cutoff})."
)
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio:
raise AlignRatioError(
"Proportion of residues aligned to query too small. "
f"Align ratio: {align_ratio}."
)
if duplicate:
raise DuplicateError(
"Template is an exact subsequence of query with large "
f"coverage. Length ratio: {length_ratio}."
)
if len(template_sequence) < 10:
raise LengthError(
f"Template too short. Length: {len(template_sequence)}."
)
return True
def _find_template_in_pdb(
template_chain_id: str,
template_sequence: str,
mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
"""Tries to find the template chain in the given pdb file.
This method tries the three following things in order:
1. Tries if there is an exact match in both the chain ID and the sequence.
If yes, the chain sequence is returned. Otherwise:
2. Tries if there is an exact match only in the sequence.
If yes, the chain sequence is returned. Otherwise:
3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
If yes, the chain sequence is returned.
If none of these succeed, a SequenceNotInTemplateError is thrown.
Args:
template_chain_id: The template chain ID.
template_sequence: The template chain sequence.
mmcif_object: The PDB object to search for the template in.
Returns:
A tuple with:
* The chain sequence that was found to match the template in the PDB object.
* The ID of the chain that is being returned.
* The offset where the template sequence starts in the chain sequence.
Raises:
SequenceNotInTemplateError: If no match is found after the steps described
above.
"""
# Try if there is an exact match in both the chain ID and the (sub)sequence.
pdb_id = mmcif_object.file_id
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
if chain_sequence and (template_sequence in chain_sequence):
logging.info(
"Found an exact template match %s_%s.", pdb_id, template_chain_id
)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, template_chain_id, mapping_offset
# Try if there is an exact match in the (sub)sequence only.
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
if chain_sequence and (template_sequence in chain_sequence):
logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, chain_id, mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
regex = re.compile("".join(regex))
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
match = re.search(regex, chain_sequence)
if match:
logging.info(
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
)
mapping_offset = match.start()
return chain_sequence, chain_id, mapping_offset
# No hits, raise an error.
raise SequenceNotInTemplateError(
"Could not find the template sequence in %s_%s. Template sequence: %s, "
"chain_to_seqres: %s"
% (
pdb_id,
template_chain_id,
template_sequence,
mmcif_object.chain_to_seqres,
)
)
def _realign_pdb_template_to_query(
old_template_sequence: str,
template_chain_id: str,
mmcif_object: mmcif_parsing.MmcifObject,
old_mapping: Mapping[int, int],
kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
"""Aligns template from the mmcif_object to the query.
In case PDB70 contains a different version of the template sequence, we need
to perform a realignment to the actual sequence that is in the mmCIF file.
This method performs such realignment, but returns the new sequence and
mapping only if the sequence in the mmCIF file is 90% identical to the old
sequence.
Note that the old_template_sequence comes from the hit, and contains only that
part of the chain that matches with the query while the new_template_sequence
is the full chain.
Args:
old_template_sequence: The template sequence that was returned by the PDB
template search (typically done using HHSearch).
template_chain_id: The template chain id was returned by the PDB template
search (typically done using HHSearch). This is used to find the right
chain in the mmcif_object chain_to_seqres mapping.
mmcif_object: A mmcif_object which holds the actual template data.
old_mapping: A mapping from the query sequence to the template sequence.
This mapping will be used to compute the new mapping from the query
sequence to the actual mmcif_object template sequence by aligning the
old_template_sequence and the actual template sequence.
kalign_binary_path: The path to a kalign executable.
Returns:
A tuple (new_template_sequence, new_query_to_template_mapping) where:
* new_template_sequence is the actual template sequence that was found in
the mmcif_object.
* new_query_to_template_mapping is the new mapping from the query to the
actual template found in the mmcif_object.
Raises:
QueryToTemplateAlignError:
* If there was an error thrown by the alignment tool.
* Or if the actual template sequence differs by more than 10% from the
old_template_sequence.
"""
aligner = kalign.Kalign(binary_path=kalign_binary_path)
new_template_sequence = mmcif_object.chain_to_seqres.get(
template_chain_id, ""
)
# Sometimes the template chain id is unknown. But if there is only a single
# sequence within the mmcif_object, it is safe to assume it is that one.
if not new_template_sequence:
if len(mmcif_object.chain_to_seqres) == 1:
logging.info(
"Could not find %s in %s, but there is only 1 sequence, so "
"using that one.",
template_chain_id,
mmcif_object.file_id,
)
new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
0
]
else:
raise QueryToTemplateAlignError(
f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
"If there are no mmCIF parsing errors, it is possible it was not a "
"protein chain."
)
try:
(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence])
)
except Exception as e:
raise QueryToTemplateAlignError(
"Could not align old template %s to template %s (%s_%s). Error: %s"
% (
old_template_sequence,
new_template_sequence,
mmcif_object.file_id,
template_chain_id,
str(e),
)
)
logging.info(
"Old aligned template: %s\nNew aligned template: %s",
old_aligned_template,
new_aligned_template,
)
old_to_new_template_mapping = {}
old_template_index = -1
new_template_index = -1
num_same = 0
for old_template_aa, new_template_aa in zip(
old_aligned_template, new_aligned_template
):
if old_template_aa != "-":
old_template_index += 1
if new_template_aa != "-":
new_template_index += 1
if old_template_aa != "-" and new_template_aa != "-":
old_to_new_template_mapping[old_template_index] = new_template_index
if old_template_aa == new_template_aa:
num_same += 1
# Require at least 90 % sequence identity wrt to the shorter of the sequences.
if (
float(num_same)
/ min(len(old_template_sequence), len(new_template_sequence))
< 0.9
):
raise QueryToTemplateAlignError(
"Insufficient similarity of the sequence in the database: %s to the "
"actual sequence in the mmCIF file %s_%s: %s. We require at least "
"90 %% similarity wrt to the shorter of the sequences. This is not a "
"problem unless you think this is a template that should be included."
% (
old_template_sequence,
mmcif_object.file_id,
template_chain_id,
new_template_sequence,
)
)
new_query_to_template_mapping = {}
for query_index, old_template_index in old_mapping.items():
new_query_to_template_mapping[
query_index
] = old_to_new_template_mapping.get(old_template_index, -1)
new_template_sequence = new_template_sequence.replace("-", "")
return new_template_sequence, new_query_to_template_mapping
def _check_residue_distances(
all_positions: np.ndarray,
all_positions_mask: np.ndarray,
max_ca_ca_distance: float,
):
"""Checks if the distance between unmasked neighbor residues is ok."""
ca_position = residue_constants.atom_order["CA"]
prev_is_unmasked = False
prev_calpha = None
for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
this_is_unmasked = bool(mask[ca_position])
if this_is_unmasked:
this_calpha = coords[ca_position]
if prev_is_unmasked:
distance = np.linalg.norm(this_calpha - prev_calpha)
if distance > max_ca_ca_distance:
raise CaDistanceError(
"The distance between residues %d and %d is %f > limit %f."
% (i, i + 1, distance, max_ca_ca_distance)
)
prev_calpha = this_calpha
prev_is_unmasked = this_is_unmasked
def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float,
_zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object,
chain_id=auth_chain_id,
_zero_center_positions=_zero_center_positions,
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
all_atom_positions, all_atom_mask, max_ca_ca_distance
)
return all_atom_positions, all_atom_mask
def _extract_template_features(
mmcif_object: mmcif_parsing.MmcifObject,
pdb_id: str,
mapping: Mapping[int, int],
template_sequence: str,
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str,
_zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide
with their corresponding residue in the query sequence, according to the
alignment mapping provided.
Args:
mmcif_object: mmcif_parsing.MmcifObject representing the template.
pdb_id: PDB code for the template.
mapping: Dictionary mapping indices in the query sequence to indices in
the template sequence.
template_sequence: String describing the amino acid sequence for the
template protein.
query_sequence: String describing the amino acid sequence for the query
protein.
template_chain_id: String ID describing which chain in the structure proto
should be used.
kalign_binary_path: The path to a kalign executable used for template
realignment.
Returns:
A tuple with:
* A dictionary containing the extra features derived from the template
protein structure.
* A warning message if the hit was realigned to the actual mmCIF sequence.
Otherwise None.
Raises:
NoChainsError: If the mmcif object doesn't contain any chains.
SequenceNotInTemplateError: If the given chain id / sequence can't
be found in the mmcif object.
QueryToTemplateAlignError: If the actual template in the mmCIF file
can't be aligned to the query.
NoAtomDataInTemplateError: If the mmcif object doesn't contain
atom positions.
TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
unmasked residues.
"""
if mmcif_object is None or not mmcif_object.chain_to_seqres:
raise NoChainsError(
"No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
)
warning = None
try:
seqres, chain_id, mapping_offset = _find_template_in_pdb(
template_chain_id=template_chain_id,
template_sequence=template_sequence,
mmcif_object=mmcif_object,
)
except SequenceNotInTemplateError:
# If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object.
chain_id = template_chain_id
warning = (
f"The exact sequence {template_sequence} was not found in "
f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
)
logging.warning(warning)
# This throws an exception if it fails to realign the hit.
seqres, mapping = _realign_pdb_template_to_query(
old_template_sequence=template_sequence,
template_chain_id=template_chain_id,
mmcif_object=mmcif_object,
old_mapping=mapping,
kalign_binary_path=kalign_binary_path,
)
logging.info(
"Sequence in %s_%s: %s successfully realigned to %s",
pdb_id,
chain_id,
template_sequence,
seqres,
)
# The template sequence changed.
template_sequence = seqres
# No mapping offset, the query is aligned to the actual sequence.
mapping_offset = 0
try:
# Essentially set to infinity - we don't want to reject templates unless
# they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object,
chain_id,
max_ca_ca_distance=150.0,
_zero_center_positions=_zero_center_positions,
)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
"Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
) from ex
all_atom_positions = np.split(
all_atom_positions, all_atom_positions.shape[0]
)
all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
output_templates_sequence = []
templates_all_atom_positions = []
templates_all_atom_masks = []
for _ in query_sequence:
# Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions.append(
np.zeros((residue_constants.atom_type_num, 3))
)
templates_all_atom_masks.append(
np.zeros(residue_constants.atom_type_num)
)
output_templates_sequence.append("-")
for k, v in mapping.items():
template_index = v + mapping_offset
templates_all_atom_positions[k] = all_atom_positions[template_index][0]
templates_all_atom_masks[k] = all_atom_masks[template_index][0]
output_templates_sequence[k] = template_sequence[v]
# Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
if np.sum(templates_all_atom_masks) < 5:
raise TemplateAtomMaskAllZerosError(
"Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
% (
pdb_id,
chain_id,
min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset,
)
)
output_templates_sequence = "".join(output_templates_sequence)
templates_aatype = residue_constants.sequence_to_onehot(
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
)
return (
{
"template_all_atom_positions": np.array(
templates_all_atom_positions
),
"template_all_atom_mask": np.array(templates_all_atom_masks),
"template_sequence": output_templates_sequence.encode(),
"template_aatype": np.array(templates_aatype),
"template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
},
warning,
)
def _build_query_to_hit_index_mapping(
hit_query_sequence: str,
hit_sequence: str,
indices_hit: Sequence[int],
indices_query: Sequence[int],
original_query_sequence: str,
) -> Mapping[int, int]:
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
characters. hit_query_sequence contains only the part of the original query
sequence that matched the hit. When interpreting the indices from the .hhr, we
need to correct for this to recover a mapping from original query sequence to
the hit sequence.
Args:
hit_query_sequence: The portion of the query sequence that is in the .hhr
hit
hit_sequence: The portion of the hit sequence that is in the .hhr
indices_hit: The indices for each aminoacid relative to the hit sequence
indices_query: The indices for each aminoacid relative to the original query
sequence
original_query_sequence: String describing the original query sequence.
Returns:
Dictionary with indices in the original query sequence as keys and indices
in the hit sequence as values.
"""
# If the hit is empty (no aligned residues), return empty mapping
if not hit_query_sequence:
return {}
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence = hit_query_sequence.replace("-", "")
hit_sequence = hit_sequence.replace("-", "")
hhsearch_query_offset = original_query_sequence.find(
hhsearch_query_sequence
)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx = min(x for x in indices_hit if x > -1)
fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
min_idx = min(x for x in indices_query if x > -1)
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
# Zip the corrected indices, ignore case where both seqs have gap characters.
mapping = {}
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
if q_t != -1 and q_i != -1:
if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
original_query_sequence
):
continue
mapping[q_i + hhsearch_query_offset] = q_t
return mapping
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
valid: bool
error: Optional[str]
warning: Optional[str]
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
features: Optional[Mapping[str, Any]]
error: Optional[str]
warning: Optional[str]
def _prefilter_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
strict_error_check: bool = False,
):
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
if hit_pdb_code not in release_dates:
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
# Pass hit_pdb_code since it might have changed due to the pdb being
# obsolete.
try:
_assess_hhsearch_hit(
hit=hit,
hit_pdb_code=hit_pdb_code,
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date,
)
except PrefilterError as e:
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)
):
# In strict mode we treat some prefilter cases as errors.
return PrefilterResult(valid=False, error=msg, warning=None)
return PrefilterResult(valid=False, error=None, warning=None)
return PrefilterResult(valid=True, error=None, warning=None)
def _process_single_hit(
query_sequence: str,
query_pdb_code: Optional[str],
hit: parsers.TemplateHit,
mmcif_dir: str,
max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False,
_zero_center_positions: bool = True,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
if hit_pdb_code not in release_dates:
if hit_pdb_code in obsolete_pdbs:
hit_pdb_code = obsolete_pdbs[hit_pdb_code]
mapping = _build_query_to_hit_index_mapping(
hit.query,
hit.hit_sequence,
hit.indices_hit,
hit.indices_query,
query_sequence,
)
# The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which regardless have a missing confidence score).
template_sequence = hit.hit_sequence.replace("-", "")
cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
logging.info(
"Reading PDB entry from %s. Query: %s, template: %s",
cif_path,
query_sequence,
template_sequence,
)
# Fail if we can't find the mmCIF file.
with open(cif_path, "r") as cif_file:
cif_string = cif_file.read()
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string
)
if parsing_result.mmcif_object is not None:
hit_release_date = datetime.datetime.strptime(
parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
)
if hit_release_date > max_template_date:
error = "Template %s date (%s) > max template date (%s)." % (
hit_pdb_code,
hit_release_date,
max_template_date,
)
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
logging.info(error)
return SingleHitResult(features=None, error=None, warning=None)
try:
features, realign_warning = _extract_template_features(
mmcif_object=parsing_result.mmcif_object,
pdb_id=hit_pdb_code,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=_zero_center_positions,
)
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# computed. In such case the mmCIF parsing errors are not relevant.
return SingleHitResult(
features=features, error=None, warning=realign_warning
)
except (
NoChainsError,
NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError,
) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None)
else:
return SingleHitResult(features=None, error=None, warning=warning)
except Error as e:
error = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
return SingleHitResult(features=None, error=error, warning=None)
def get_custom_template_features(
mmcif_path: str,
query_sequence: str,
pdb_id: str,
chain_id: str,
kalign_binary_path: str):
with open(mmcif_path, "r") as mmcif_path:
cif_string = mmcif_path.read()
mmcif_parse_result = mmcif_parsing.parse(
file_id=pdb_id, mmcif_string=cif_string
)
template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]
mapping = {x:x for x, _ in enumerate(query_sequence)}
features, warnings = _extract_template_features(
mmcif_object=mmcif_parse_result.mmcif_object,
pdb_id=pdb_id,
mapping=mapping,
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=chain_id,
kalign_binary_path=kalign_binary_path,
_zero_center_positions=True
)
features["template_sum_probs"] = [1.0]
# TODO: clean up this logic
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
for k in template_features:
template_features[k].append(features[k])
for name in template_features:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
return TemplateSearchResult(
features=template_features, errors=None, warnings=warnings
)
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
features: Mapping[str, Any]
errors: Sequence[str]
warnings: Sequence[str]
class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features."""
def __init__(
self,
mmcif_dir: str,
max_template_date: str,
max_hits: int,
kalign_binary_path: str,
release_dates_path: Optional[str] = None,
obsolete_pdbs_path: Optional[str] = None,
strict_error_check: bool = False,
_shuffle_top_k_prefiltered: Optional[int] = None,
_zero_center_positions: bool = True,
):
"""Initializes the Template Search.
Args:
mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
is found by HHSearch, this directory is used to retrieve the template
data.
max_template_date: The maximum date permitted for template structures. No
template with date higher than this date will be returned. In ISO8601
date format, YYYY-MM-DD.
max_hits: The maximum number of templates that will be returned.
kalign_binary_path: The path to a kalign executable used for template
realignment.
release_dates_path: An optional path to a file with a mapping from PDB IDs
to their release dates. Thanks to this we don't have to redundantly
parse mmCIF files to get that information.
obsolete_pdbs_path: An optional path to a file containing a mapping from
obsolete PDB IDs to the PDB IDs of their replacements.
strict_error_check: If True, then the following will be treated as errors:
* If any template date is after the max_template_date.
* If any template has identical PDB ID to the query.
* If any template is a duplicate of the query.
* Any feature computation errors.
"""
self._mmcif_dir = mmcif_dir
if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
logging.error("Could not find CIFs in %s", self._mmcif_dir)
raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")
try:
self._max_template_date = datetime.datetime.strptime(
max_template_date, "%Y-%m-%d"
)
except ValueError:
raise ValueError(
"max_template_date must be set and have format YYYY-MM-DD."
)
self.max_hits = max_hits
self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check
if release_dates_path:
logging.info(
"Using precomputed release dates %s.", release_dates_path
)
self._release_dates = _parse_release_dates(release_dates_path)
else:
self._release_dates = {}
if obsolete_pdbs_path:
logging.info(
"Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
)
self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
else:
self._obsolete_pdbs = {}
self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
self._zero_center_positions = _zero_center_positions
def get_templates(
self,
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info("Searching for template for: %s", query_pdb_code)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []
# Always use a max_template_date. Set to query_release_date minus 60 days
# if that's earlier.
template_cutoff_date = self._max_template_date
if query_release_date:
delta = datetime.timedelta(days=60)
if query_release_date - delta < template_cutoff_date:
template_cutoff_date = query_release_date - delta
assert template_cutoff_date < query_release_date
assert template_cutoff_date <= self._max_template_date
num_hits = 0
errors = []
warnings = []
filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)
filtered = list(
sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
)
idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])
for i in idx:
# We got all the templates we wanted, stop processing hits.
if num_hits >= self.max_hits:
break
hit = filtered[i]
result = _process_single_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
hit=hit,
mmcif_dir=self._mmcif_dir,
max_template_date=template_cutoff_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path,
_zero_center_positions=self._zero_center_positions,
)
if result.error:
errors.append(result.error)
# There could be an error even if there are some results, e.g. thrown by
# other unparsable chains in the same mmCIF file.
if result.warning:
warnings.append(result.warning)
if result.features is None:
logging.info(
"Skipped invalid hit %s, error: %s, warning: %s",
hit.name,
result.error,
result.warning,
)
else:
# Increment the hit counter, since we got features out of this hit.
num_hits += 1
for k in template_features:
template_features[k].append(result.features[k])
for name in template_features:
if num_hits > 0:
template_features[name] = np.stack(
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
# Make sure the feature has correct dtype even if empty.
template_features[name] = np.array(
[], dtype=TEMPLATE_FEATURES[name]
)
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""
import glob
import logging
import os
import subprocess
from typing import Any, Mapping, Optional, Sequence
from openfold.data.tools import utils
_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500
class HHBlits:
"""Python wrapper of the HHblits binary."""
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 4,
n_iter: int = 3,
e_value: float = 0.001,
maxseq: int = 1_000_000,
realign_max: int = 100_000,
maxfilt: int = 100_000,
min_prefilter_hits: int = 1000,
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z,
):
"""Initializes the Python HHblits wrapper.
Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
for database_path in self.databases:
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHBlits database %s", database_path
)
raise ValueError(
f"Could not find HHBlits database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.maxseq = maxseq
self.realign_max = realign_max
self.maxfilt = maxfilt
self.min_prefilter_hits = min_prefilter_hits
self.all_seqs = all_seqs
self.alt = alt
self.p = p
self.z = z
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
for db_path in self.databases:
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
"-i",
input_fasta_path,
"-cpu",
str(self.n_cpu),
"-oa3m",
a3m_path,
"-o",
"/dev/null",
"-n",
str(self.n_iter),
"-e",
str(self.e_value),
"-maxseq",
str(self.maxseq),
"-realign_max",
str(self.realign_max),
"-maxfilt",
str(self.maxfilt),
"-min_prefilter_hits",
str(self.min_prefilter_hits),
]
if self.all_seqs:
cmd += ["-all"]
if self.alt:
cmd += ["-alt", str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ["-p", str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ["-Z", str(self.z)]
cmd += db_cmd
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHblits query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error("HHblits failed. HHblits stderr begin:")
for error_line in stderr.decode("utf-8").splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error("HHblits stderr end")
raise RuntimeError(
"HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
)
with open(a3m_path) as f:
a3m = f.read()
raw_output = dict(
a3m=a3m,
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value,
)
return raw_output
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""
import glob
import logging
import os
import subprocess
from typing import Sequence
from openfold.data.tools import utils
class HHSearch:
"""Python wrapper of the HHsearch binary."""
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000,
):
"""Initializes the Python HHsearch wrapper.
Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
self.n_cpu = n_cpu
self.maxseq = maxseq
for database_path in self.databases:
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHsearch database %s", database_path
)
raise ValueError(
f"Could not find HHsearch database {database_path}"
)
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
db_cmd = []
for db_path in self.databases:
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
"-i",
input_path,
"-o",
hhr_path,
"-maxseq",
str(self.maxseq),
"-cpu",
str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHsearch query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
"HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
)
with open(hhr_path) as f:
hhr = f.read()
return hhr
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run Jackhmmer from Python."""
from concurrent import futures
import glob
import logging
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from openfold.data.tools import utils
class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""
def __init__(
self,
*,
binary_path: str,
database_path: str,
n_cpu: int = 8,
n_iter: int = 1,
e_value: float = 0.0001,
z_value: Optional[int] = None,
get_tblout: bool = False,
filter_f1: float = 0.0005,
filter_f2: float = 0.00005,
filter_f3: float = 0.0000005,
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None,
num_streamed_chunks: Optional[int] = None,
streaming_callback: Optional[Callable[[int], None]] = None,
):
"""Initializes the Python Jackhmmer wrapper.
Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value, see Jackhmmer docs for more details.
get_tblout: Whether to save tblout string.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
round.
dom_e: Domain e-value criteria for inclusion in tblout.
num_streamed_chunks: Number of database chunks to stream over.
streaming_callback: Callback function run after each chunk iteration with
the iteration number as argument.
"""
self.binary_path = binary_path
self.database_path = database_path
self.num_streamed_chunks = num_streamed_chunks
if (
not os.path.exists(self.database_path)
and num_streamed_chunks is None
):
logging.error("Could not find Jackhmmer database %s", database_path)
raise ValueError(
f"Could not find Jackhmmer database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.z_value = z_value
self.filter_f1 = filter_f1
self.filter_f2 = filter_f2
self.filter_f3 = filter_f3
self.incdom_e = incdom_e
self.dom_e = dom_e
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback
def _query_chunk(
self, input_fasta_path: str, database_path: str
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
# speeds up the pipeline at the expensive of sensitivity. They are
# currently set very low to make querying Mgnify run in a reasonable
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
"-o",
"/dev/null",
"-A",
sto_path,
"--noali",
"--F1",
str(self.filter_f1),
"--F2",
str(self.filter_f2),
"--F3",
str(self.filter_f3),
"--incE",
str(self.e_value),
# Report only sequences with E-values <= x in per-sequence output.
"-E",
str(self.e_value),
"--cpu",
str(self.n_cpu),
"-N",
str(self.n_iter),
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
cmd_flags.extend(["--tblout", tblout_path])
if self.z_value:
cmd_flags.extend(["-Z", str(self.z_value)])
if self.dom_e is not None:
cmd_flags.extend(["--domE", str(self.dom_e)])
if self.incdom_e is not None:
cmd_flags.extend(["--incdomE", str(self.incdom_e)])
cmd = (
[self.binary_path]
+ cmd_flags
+ [input_fasta_path, database_path]
)
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing(
f"Jackhmmer ({os.path.basename(database_path)}) query"
):
_, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
"Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
)
# Get e-values for each target name
tbl = ""
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()
with open(sto_path) as f:
sto = f.read()
raw_output = dict(
sto=sto,
tbl=tbl,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value,
)
return raw_output
def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"
# Remove existing files to prevent OOM
for f in glob.glob(db_local_chunk("[0-9]*")):
try:
os.remove(f)
except OSError:
print(f"OSError while deleting {f}")
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_output = []
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
future = executor.submit(
request.urlretrieve,
db_remote_chunk(i),
db_local_chunk(i),
)
if i < self.num_streamed_chunks:
next_future = executor.submit(
request.urlretrieve,
db_remote_chunk(i + 1),
db_local_chunk(i + 1),
)
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i))
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence
from absl import logging
from openfold.data.tools import utils
def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(u">" + name + u"\n")
a3m.append(sequence + u"\n")
return "".join(a3m)
class Kalign:
"""Python wrapper of the Kalign binary."""
def __init__(self, *, binary_path: str):
"""Initializes the Python Kalign wrapper.
Args:
binary_path: The path to the Kalign binary.
Raises:
RuntimeError: If Kalign binary not found within the path.
"""
self.binary_path = binary_path
def align(self, sequences: Sequence[str]) -> str:
"""Aligns the sequences and returns the alignment in A3M string.
Args:
sequences: A list of query sequence strings. The sequences have to be at
least 6 residues long (Kalign requires this). Note that the order in
which you give the sequences might alter the output slightly as
different alignment tree might get constructed.
Returns:
A string with the alignment in a3m format.
Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info("Aligning %d sequences", len(sequences))
for s in sequences:
if len(s) < 6:
raise ValueError(
"Kalign requires all sequences to be at least 6 "
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
with open(input_fasta_path, "w") as f:
f.write(_to_a3m(sequences))
cmd = [
self.binary_path,
"-i",
input_fasta_path,
"-o",
output_a3m_path,
"-format",
"fasta",
]
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("Kalign query"):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info(
"Kalign stdout:\n%s\n\nstderr:\n%s\n",
stdout.decode("utf-8"),
stderr.decode("utf-8"),
)
if retcode:
raise RuntimeError(
"Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr.decode("utf-8"))
)
with open(output_a3m_path) as f:
a3m = f.read()
return a3m
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import datetime
import logging
import shutil
import tempfile
import time
from typing import Optional
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
logging.info("Started %s", msg)
tic = time.perf_counter()
yield
toc = time.perf_counter()
logging.info("Finished %s in %.3f seconds", msg, toc - tic)
def to_date(s: str):
return datetime.datetime(
year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
)
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List
class Dropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask
along a particular dimension.
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
r:
Dropout rate
batch_dim:
Dimension(s) along which the dropout mask is shared
"""
super(Dropout, self).__init__()
self.r = r
if type(batch_dim) == int:
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim
"""
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape)
mask = self.dropout(mask)
x *= mask
return x
class DropoutRowwise(Dropout):
"""
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
class DropoutColumnwise(Dropout):
"""
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple, Optional
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.tensor_utils import add, one_hot
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
pair_emb = add(pair_emb,
tf_emb_i[..., None, :],
inplace=inplace_safe
)
pair_emb = add(pair_emb,
tf_emb_j[..., None, :, :],
inplace=inplace_safe
)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
# [*, N, C_m]
m_update = self.layer_norm_m(m)
if(inplace_safe):
m.copy_(m_update)
m_update = m
# [*, N, N, C_z]
z_update = self.layer_norm_z(z)
if(inplace_safe):
z.copy_(z_update)
z_update = z
# This squared method might become problematic in FP16 mode.
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
squared_bins = bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = add(z_update, d, inplace_safe)
return m_update, z_update
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
from openfold.model.msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from openfold.model.outer_product_mean import OuterProductMean
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
TriangleAttention,
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer, ChunkSizeTuner
from openfold.utils.tensor_utils import add
class MSATransition(nn.Module):
"""
Feed-forward network applied to MSA activations after attention.
Implements Algorithm 9
"""
def __init__(self, c_m, n):
"""
Args:
c_m:
MSA channel dimension
n:
Factor multiplied to c_m to obtain the hidden channel
dimension
"""
super(MSATransition, self).__init__()
self.c_m = c_m
self.n = n
self.layer_norm = LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.layer_norm(m)
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
return m
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"m": m, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA activation
mask:
[*, N_seq, N_res, C_m] MSA mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if mask is None:
mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
m = self._transition(m, mask)
return m
class EvoformerBlockCore(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerBlockCore, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
self.tri_att_start = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.tri_att_end = TriangleAttention(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.pair_transition = PairTransition(
c_z,
transition_n,
)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
def forward(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
m, z = input_tensors
m = add(
m,
self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
opm = self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size, inplace_safe=inplace_safe
)
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
z = add(z, opm, inplace=inplace_safe)
del opm
tmu_update = self.tri_mul_out(
z,
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
z,
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if(not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
else:
z = tmu_update
del tmu_update
z = add(z,
self.ps_dropout_row_layer(
self.tri_att_start(
z,
mask=pair_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.ps_dropout_row_layer(
self.tri_att_end(
z,
mask=pair_mask.transpose(-1, -2),
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace=inplace_safe,
)
z = z.transpose(-2, -3)
if(inplace_safe):
input_tensors[1] = z.contiguous()
z = input_tensors[1]
z = add(z,
self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size,
),
inplace=inplace_safe,
)
if(_offload_inference and inplace_safe):
device = z.device
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
return m, z
class EvoformerBlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m,
z=z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_lma=use_lma,
)
),
inplace=inplace_safe,
)
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
del m, z
m, z = self.core(
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
return m, z
class ExtraMSABlock(nn.Module):
"""
Almost identical to the standard EvoformerBlock, except in that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
requires more fine-grained control over checkpointing. Separated from
its twin to preserve the TorchScript-ability of the latter.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: Optional[torch.Tensor],
z: Optional[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
_attn_chunk_size: Optional[int] = None,
_offload_inference: bool = False,
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
else:
input_tensors = [m, z]
m, z = input_tensors
m = add(m,
self.msa_dropout_layer(
self.msa_att_row(
m.clone() if torch.is_grad_enabled() else m,
z=z.clone() if torch.is_grad_enabled() else z,
mask=msa_mask,
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_memory_efficient_kernel=not use_lma,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, z]
del m, z
def fn(input_tensors):
m = add(input_tensors[0],
self.msa_att_col(
input_tensors[0],
mask=msa_mask,
chunk_size=chunk_size,
use_lma=use_lma,
),
inplace=inplace_safe,
)
if(not inplace_safe):
input_tensors = [m, input_tensors[1]]
del m
m, z = self.core(
input_tensors,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
_attn_chunk_size=_attn_chunk_size,
_offload_inference=_offload_inference,
)
return m, z
if(torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, input_tensors)
else:
m, z = fn(input_tensors)
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
c_s: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
tune_chunk_size:
Whether to dynamically tune the module's chunk size
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = EvoformerBlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args, **kwargs):
torch.cuda.empty_cache()
return block(*args, **kwargs)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
assert(not self.training)
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# We don't want to write in-place during chunk tuning runs
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
m, z = input_tensors
s = self.linear(m[..., 0, :, :])
return m, z, s
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
use_flash:
Whether to use FlashAttention where possible. Mutually
exclusive with use_lma.
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
blocks_per_ckpt = self.blocks_per_ckpt
if(not torch.is_grad_enabled()):
blocks_per_ckpt = None
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=blocks_per_ckpt,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
tune_chunk_size: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.ckpt = ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=False,
)
self.blocks.append(block)
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
inplace_safe: bool,
_mask_trans: bool,
):
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
) for b in self.blocks
]
def clear_cache(b, *args, **kwargs):
torch.cuda.empty_cache()
return b(*args, **kwargs)
if(self.clear_cache_between_blocks):
blocks = [partial(clear_cache, b) for b in blocks]
if(chunk_size is not None and self.chunk_size_tuner is not None):
tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
representative_fn=blocks[0],
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
) for b in blocks
]
return blocks
def _forward_offload(self,
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
assert(not (self.training or torch.is_grad_enabled()))
blocks = self._prep_blocks(
# We are very careful not to create references to these tensors in
# this function
m=input_tensors[0],
z=input_tensors[1],
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=True,
_mask_trans=_mask_trans,
)
for b in blocks:
m, z = b(
None,
None,
_offload_inference=True,
_offloadable_inputs=input_tensors,
)
input_tensors[0] = m
input_tensors[1] = z
del m, z
return input_tensors[1]
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
chunk_size: Inference-time subbatch size for Evoformer modules
use_lma: Whether to use low-memory attention during inference
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
checkpoint_fn = get_checkpoint_fn()
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
for b in blocks:
if(self.ckpt and torch.is_grad_enabled()):
m, z = checkpoint_fn(b, m, z)
else:
m, z = b(m, z)
return z
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.loss import (
compute_plddt,
compute_tm,
compute_predicted_aligned_error,
)
from openfold.utils.precision_utils import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
def __init__(self, config):
super(AuxiliaryHeads, self).__init__()
self.plddt = PerResidueLDDTCaPredictor(
**config["lddt"],
)
self.distogram = DistogramHead(
**config["distogram"],
)
self.masked_msa = MaskedMSAHead(
**config["masked_msa"],
)
self.experimentally_resolved = ExperimentallyResolvedHead(
**config["experimentally_resolved"],
)
if config.tm.enabled:
self.tm = TMScoreHead(
**config.tm,
)
self.config = config
def forward(self, outputs):
aux_out = {}
lddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on
aux_out["plddt"] = compute_plddt(lddt_logits)
distogram_logits = self.distogram(outputs["pair"])
aux_out["distogram_logits"] = distogram_logits
masked_msa_logits = self.masked_msa(outputs["msa"])
aux_out["masked_msa_logits"] = masked_msa_logits
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out[
"experimentally_resolved_logits"
] = experimentally_resolved_logits
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
**self.config.tm,
)
)
return aux_out
class PerResidueLDDTCaPredictor(nn.Module):
def __init__(self, no_bins, c_in, c_hidden):
super(PerResidueLDDTCaPredictor, self).__init__()
self.no_bins = no_bins
self.c_in = c_in
self.c_hidden = c_hidden
self.layer_norm = LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
return s
class DistogramHead(nn.Module):
"""
Computes a distogram probability distribution.
For use in computation of distogram loss, subsection 1.9.8
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of distogram bins
"""
super(DistogramHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def _forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N, N, no_bins] distogram probability distribution
"""
# [*, N, N, no_bins]
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
def forward(self, z):
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of bins
"""
super(TMScoreHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z):
"""
Args:
z:
[*, N_res, N_res, C_z] pairwise embedding
Returns:
[*, N_res, N_res, no_bins] prediction
"""
# [*, N, N, no_bins]
logits = self.linear(z)
return logits
class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""
def __init__(self, c_m, c_out, **kwargs):
"""
Args:
c_m:
MSA channel dimension
c_out:
Output channel dimension
"""
super(MaskedMSAHead, self).__init__()
self.c_m = c_m
self.c_out = c_out
self.linear = Linear(self.c_m, self.c_out, init="final")
def forward(self, m):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
Returns:
[*, N_seq, N_res, C_out] reconstruction
"""
# [*, N_seq, N_res, C_out]
logits = self.linear(m)
return logits
class ExperimentallyResolvedHead(nn.Module):
"""
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""
def __init__(self, c_s, c_out, **kwargs):
"""
Args:
c_s:
Input channel dimension
c_out:
Number of distogram bins
"""
super(ExperimentallyResolvedHead, self).__init__()
self.c_s = c_s
self.c_out = c_out
self.linear = Linear(self.c_s, self.c_out, init="final")
def forward(self, s):
"""
Args:
s:
[*, N_res, C_s] single embedding
Returns:
[*, N, C_out] logits
"""
# [*, N, C_out]
logits = self.linear(s)
return logits
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import weakref
import torch
import torch.nn as nn
from openfold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
embed_templates_average,
embed_templates_offload,
)
import openfold.np.residue_constants as residue_constants
from openfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from openfold.utils.loss import (
compute_plddt,
)
from openfold.utils.tensor_utils import (
add,
dict_multimap,
tensor_tree_map,
)
class AlphaFold(nn.Module):
"""
Alphafold 2.
Implements Algorithm 2 (but with training).
"""
def __init__(self, config):
"""
Args:
config:
A dict-like config object (like the one in config.py)
"""
super(AlphaFold, self).__init__()
self.globals = config.globals
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
# Main trunk + structure module
self.input_embedder = InputEmbedder(
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**self.config["recycling_embedder"],
)
if(self.template_config.enabled):
self.template_angle_embedder = TemplateAngleEmbedder(
**self.template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**self.template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**self.template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**self.template_config["template_pointwise_attention"],
)
if(self.extra_msa_config.enabled):
self.extra_msa_embedder = ExtraMSAEmbedder(
**self.extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**self.extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**self.config["evoformer_stack"],
)
self.structure_module = StructureModule(
**self.config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
self.config["heads"],
)
def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
if(self.template_config.offload_templates):
return embed_templates_offload(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
elif(self.template_config.average_templates):
return embed_templates_average(self,
batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
)
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if(inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.globals.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.template.use_unit_vector,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if(inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if(not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=self.globals.use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if(inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
return ret
def iteration(self, feats, prevs, _recycle=True):
# Primary output dictionary
outputs = {}
# This needs to be done manually for DeepSpeed's sake
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe = not (self.training or torch.is_grad_enabled())
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
## Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
m_1_prev, z_prev, x_prev = reversed([prevs.pop() for _ in range(3)])
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# The recycling embedder is memory-intensive, so we offload first
if(self.globals.offload_inference and inplace_safe):
m = m.cpu()
z = z.cpu()
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
inplace_safe=inplace_safe,
)
if(self.globals.offload_inference and inplace_safe):
m = m.to(m_1_prev_emb.device)
z = z.to(z_prev.device)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z = add(z, z_prev_emb, inplace=inplace_safe)
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
inplace_safe=inplace_safe,
)
# [*, N, N, C_z]
z = add(z,
template_embeds.pop("template_pair_embedding"),
inplace_safe,
)
if "template_angle_embedding" in template_embeds:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
if(self.globals.offload_inference):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors = [a, z]
del a, z
# [*, N, N, C_z]
z = self.extra_msa_stack._forward_offload(
input_tensors,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
# [*, N, N, C_z]
z = self.extra_msa_stack(
a, z,
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if(self.globals.offload_inference):
input_tensors = [m, z]
del m, z
m, z, s = self.evoformer._forward_offload(
input_tensors,
msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
del input_tensors
else:
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
del z
# Predict 3D structure
outputs["sm"] = self.structure_module(
outputs,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
inplace_safe=inplace_safe,
_offload_inference=self.globals.offload_inference,
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = outputs["pair"]
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for
for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
prevs = [m_1_prev, z_prev, x_prev]
is_grad_enabled = torch.is_grad_enabled()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
prevs,
_recycle=(num_iters > 1)
)
if(not is_final_iter):
del outputs
prevs = [m_1_prev, z_prev, x_prev]
del m_1_prev, z_prev, x_prev
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
return outputs
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
import torch
import torch.nn as nn
from typing import Optional, List, Tuple
from openfold.model.primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
)
class MSAAttention(nn.Module):
def __init__(
self,
c_in,
c_hidden,
no_heads,
pair_bias=False,
c_z=None,
inf=1e9,
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
pair_bias:
Whether to use pair embedding bias
c_z:
Pair embedding channel dimension. Ignored unless pair_bias
is true
inf:
A large number to be used in computing the attention mask
"""
super(MSAAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.pair_bias = pair_bias
self.c_z = c_z
self.inf = inf
self.layer_norm_m = LayerNorm(self.c_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = Attention(
self.c_in,
self.c_in,
self.c_in,
self.c_hidden,
self.no_heads,
)
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: Optional[List[torch.Tensor]],
chunk_size: int,
use_memory_efficient_kernel: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
) -> torch.Tensor:
def fn(m, biases, flash_mask):
m = self.layer_norm_m(m)
return self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
)
inputs = {"m": m}
if(biases is not None):
inputs["biases"] = biases
else:
fn = partial(fn, biases=None)
if(use_flash and flash_mask is not None):
inputs["flash_mask"] = flash_mask
else:
fn = partial(fn, flash_mask=None)
return chunk_layer(
fn,
inputs,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2])
)
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res),
)
# [*, N_seq, 1, 1, N_res]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
chunks = []
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
return m, mask_bias, z
@torch.jit.ignore
def _chunked_msa_attn(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
inplace_safe: bool = False
) -> torch.Tensor:
"""
MSA attention with training-time chunking of the softmax computation.
Saves memory in the extra MSA stack. Probably obviated by our fused
attention kernel, which is now used by default.
"""
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
m = self.layer_norm_m(m)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(torch.is_grad_enabled() and checkpoint):
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
m, q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
key=k,
value=v,
biases=[mask_bias, z],
chunk_size=chunk_logits,
chunk_dim=MSA_DIM,
checkpoint=checkpoint,
)
if(torch.is_grad_enabled() and checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
m = self.mha._wrap_up(o, m)
return m
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits,
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(
m,
biases,
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
else:
m = self.layer_norm_m(m)
m = self.mha(
q_x=m,
kv_x=m,
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
)
return m
class MSARowAttentionWithPairBias(MSAAttention):
"""
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
Input channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSARowAttentionWithPairBias, self).__init__(
c_m,
c_hidden,
no_heads,
pair_bias=True,
c_z=c_z,
inf=inf,
)
class MSAColumnAttention(nn.Module):
"""
Implements Algorithm 8.
By rights, this should also be a subclass of MSAAttention. Alas,
most inheritance isn't supported by TorchScript.
"""
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
MSA channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSAColumnAttention, self).__init__()
self.c_m = c_m
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self._msa_att = MSAAttention(
c_in=c_m,
c_hidden=c_hidden,
no_heads=no_heads,
pair_bias=False,
c_z=None,
inf=inf,
)
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
if mask is not None:
mask = mask.transpose(-1, -2)
m = self._msa_att(
m,
mask=mask,
chunk_size=chunk_size,
use_lma=use_lma,
use_flash=use_flash,
)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
if mask is not None:
mask = mask.transpose(-1, -2)
return m
class MSAColumnGlobalAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
):
super(MSAColumnGlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.eps = eps
self.layer_norm_m = nn.LayerNorm(c_in)
self.global_attention = GlobalAttention(
c_in=c_in,
c_hidden=c_hidden,
no_heads=no_heads,
inf=inf,
eps=eps,
)
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
use_lma: bool = False,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
def fn(m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask, use_lma=use_lma)
return chunk_layer(
fn,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_lma: bool = False,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
if mask is None:
# [*, N_seq, N_res]
mask = torch.ones(
m.shape[:-1],
dtype=m.dtype,
device=m.device,
).detach()
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
else:
m = self.layer_norm_m(m)
m = self.global_attention(m=m, mask=mask, use_lma=use_lma)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
return m
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
class OuterProductMean(nn.Module):
"""
Implements Algorithm 10.
"""
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
"""
Args:
c_m:
MSA embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(OuterProductMean, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.c_hidden = c_hidden
self.eps = eps
self.layer_norm = nn.LayerNorm(c_m)
self.linear_1 = Linear(c_m, c_hidden)
self.linear_2 = Linear(c_m, c_hidden)
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
def _opm(self, a, b):
# [*, N_res, N_res, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C]
outer = outer.reshape(outer.shape[:-2] + (-1,))
# [*, N_res, N_res, C_z]
outer = self.linear_out(outer)
return outer
@torch.jit.ignore
def _chunk(self,
a: torch.Tensor,
b: torch.Tensor,
chunk_size: int
) -> torch.Tensor:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
a_reshape = a.reshape((-1,) + a.shape[-3:])
b_reshape = b.reshape((-1,) + b.shape[-3:])
out = []
for a_prime, b_prime in zip(a_reshape, b_reshape):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=chunk_size,
no_batch_dims=1,
)
out.append(outer)
# For some cursed reason making this distinction saves memory
if(len(out) == 1):
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def _forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if mask is None:
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
ln = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(ln)
a = a * mask
b = self.linear_2(ln)
b = b * mask
del ln
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if chunk_size is not None:
outer = self._chunk(a, b, chunk_size)
else:
outer = self._opm(a, b)
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
norm = norm + self.eps
# [*, N_res, N_res, C_z]
if(inplace_safe):
outer /= norm
else:
outer = outer / norm
return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
inplace_safe: bool = False,
) -> torch.Tensor:
if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
return self._forward(m.float(), mask, chunk_size, inplace_safe)
else:
return self._forward(m, mask, chunk_size, inplace_safe)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.chunk_utils import chunk_layer
class PairTransition(nn.Module):
"""
Implements Algorithm 15.
"""
def __init__(self, c_z, n):
"""
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
"""
super(PairTransition, self).__init__()
self.c_z = c_z
self.n = n
self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z)
z = z * mask
return z
@torch.jit.ignore
def _chunk(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"z": z, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if mask is None:
mask = z.new_ones(z.shape[:-1])
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
z = self._transition(z=z, mask=mask)
return z
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment