Unverified commit 4693058b authored by Fazzie-Maqianli, committed by GitHub

support multimer (#63)

parent c80a4df5
......@@ -34,6 +34,7 @@ from fastfold.data import (
msa_pairing,
feature_processing_multimer,
)
from fastfold.data import templates
from fastfold.data.parsers import Msa
from fastfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from fastfold.data.tools.utils import to_date
......@@ -57,7 +58,7 @@ def empty_template_feats(n_res) -> FeatureDict:
def make_template_features(
input_sequence: str,
hits: Sequence[Any],
template_featurizer: Union[hhsearch.HHSearch, hmmsearch.Hmmsearch],
template_featurizer: Union[templates.TemplateHitFeaturizer, templates.HmmsearchHitFeaturizer],
query_pdb_code: Optional[str] = None,
query_release_date: Optional[str] = None,
) -> FeatureDict:
......@@ -65,7 +66,7 @@ def make_template_features(
if(len(hits_cat) == 0 or template_featurizer is None):
template_features = empty_template_feats(len(input_sequence))
else:
if type(template_featurizer) == hhsearch.HHSearch:
if type(template_featurizer) == templates.TemplateHitFeaturizer:
templates_result = template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=query_pdb_code,
......@@ -202,32 +203,35 @@ def make_pdb_features(
return pdb_feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
)
deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
num_res = len(msas[0][0])
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index]
)
species_ids.append(identifiers.species_id.encode('utf-8'))
num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
......@@ -235,9 +239,9 @@ def make_msa_features(
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
return features
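For orientation, a minimal sketch of the new `Msa`-based call (toy alignment invented for this note; feature keys as built above):

from fastfold.data import parsers

toy_msa = parsers.Msa(
    sequences=["MKTAY", "MKSAY"],
    deletion_matrix=[[0] * 5, [0] * 5],
    descriptions=["query", "hit_1"],
)
feats = make_msa_features([toy_msa])
# feats["deletion_matrix_int"]: [2, 5] int32 array
# feats["num_alignments"]: 2, repeated num_res times
# feats["msa_species_identifiers"]: one byte string per row (empty here,
# since these descriptions carry no recognizable species identifier)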
def run_msa_tool(
msa_runner,
fasta_path: str,
......@@ -455,7 +459,7 @@ class AlignmentRunner:
class AlignmentRunnerMultimer(AlignmentRunner):
class AlignmentRunnerMultimer:
"""Runs alignment tools and saves the results"""
def __init__(
......@@ -504,7 +508,6 @@ class AlignmentRunnerMultimer(AlignmentRunner):
mgnify_max_hits:
Max number of mgnify hits
"""
# super().__init__()
db_map = {
"jackhmmer": {
"binary": jackhmmer_binary_path,
......@@ -810,43 +813,41 @@ class DataPipeline:
return msa
for (name, start, size) in _alignment_index["files"]:
ext = os.path.splitext(name)[-1]
filename, ext = os.path.splitext(name)
if(ext == ".a3m"):
msa, deletion_matrix = parsers.parse_a3m(
msa = parsers.parse_a3m(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa, deletion_matrix, _ = parsers.parse_stockholm(
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
elif(ext == ".sto" and not "hmm_output" == filename):
msa = parsers.parse_stockholm(
read_msa(start, size)
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[name] = data
msa_data[name] = msa
fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
filename, ext = os.path.splitext(f)
if(ext == ".a3m"):
with open(path, "r") as fp:
msa, deletion_matrix = parsers.parse_a3m(fp.read())
data = {"msa": msa, "deletion_matrix": deletion_matrix}
elif(ext == ".sto"):
msa = parsers.parse_a3m(fp.read())
elif(ext == ".sto" and not "hmm_output" == filename):
with open(path, "r") as fp:
msa, deletion_matrix, _ = parsers.parse_stockholm(
msa = parsers.parse_stockholm(
fp.read()
)
data = {"msa": msa, "deletion_matrix": deletion_matrix}
else:
continue
msa_data[f] = data
msa_data[f] = msa
return msa_data
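To make the branches above concrete, a plausible (hypothetical) alignment_dir layout and the keys it yields:

# alignment_dir/              -> msa_data key (value: parsers.Msa)
#   uniref90_hits.sto         -> "uniref90_hits.sto"
#   mgnify_hits.sto           -> "mgnify_hits.sto"
#   bfd_uniclust_hits.a3m     -> "bfd_uniclust_hits.a3m"
#   hmm_output.sto            -> skipped (multimer template hits)
#   anything_else.hhr         -> skipped (unhandled extension)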
......@@ -913,19 +914,13 @@ class DataPipeline:
must be provided.
"""
)
msa_data["dummy"] = {
"msa": [input_sequence],
"deletion_matrix": [[0 for _ in input_sequence]],
}
msas, deletion_matrices = zip(*[
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
])
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
msa_data["dummy"] = Msa(
[input_sequence],
[[0 for _ in input_sequence]],
["dummy"]
)
msa_features = make_msa_features(list(msa_data.values()))
return msa_features
......@@ -996,7 +991,10 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -1014,13 +1012,24 @@ class DataPipeline:
alignment_dir: str,
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
"""
with open(pdb_path, 'r') as f:
pdb_str = f.read()
if(_structure_index is not None):
db_dir = os.path.dirname(pdb_path)
db = _structure_index["db"]
db_path = os.path.join(db_dir, db)
fp = open(db_path, "rb")
_, offset, length = _structure_index["files"][0]
fp.seek(offset)
pdb_str = fp.read(length).decode("utf-8")
fp.close()
else:
with open(pdb_path, 'r') as f:
pdb_str = f.read()
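# Sketch of the assumed `_structure_index` shape (field names inferred from
# the lookup above; offsets/lengths are illustrative only):
# _structure_index = {
#     "db": "pdb_all.db",               # packed file of concatenated PDB entries
#     "files": [("name", 1024, 2048)],  # (name, byte offset, length)
# }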
protein_object = protein.from_pdb_string(pdb_str, chain_id)
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
......@@ -1028,10 +1037,14 @@ class DataPipeline:
pdb_feats = make_pdb_features(
protein_object,
description,
is_distillation
is_distillation=is_distillation
)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
......@@ -1059,7 +1072,11 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(
alignment_dir,
input_sequence,
_alignment_index
)
template_features = make_template_features(
input_sequence,
hits,
......@@ -1123,8 +1140,8 @@ class DataPipelineMultimer:
uniprot_msa_path = os.path.join(alignment_dir, "uniprot_hits.sto")
with open(uniprot_msa_path, "r") as fp:
uniprot_msa_string = fp.read()
msa, deletion_matrix, _ = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features(msa, deletion_matrix)
msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features = make_msa_features([msa])
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
)
......
......@@ -76,8 +76,9 @@ def np_example_to_features(
mode: str,
):
np_example = dict(np_example)
print("np_example seq_length", np_example["seq_length"])
if is_multimer:
num_res = int(np_example["seq_length"])
num_res = int(np_example["seq_length"][0])
else:
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
......
......@@ -96,9 +96,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions
def parse_stockholm(
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
def parse_stockholm(stockholm_string: str) -> Msa:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
......@@ -153,10 +151,14 @@ def parse_stockholm(
deletion_count = 0
deletion_matrix.append(deletion_vec)
return msa, deletion_matrix, list(name_to_sequence.keys())
return Msa(
sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys())
)
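A toy illustration of the new return type (alignment invented for this note):

sto = (
    "# STOCKHOLM 1.0\n"
    "query M-KV\n"
    "hit1  MAKV\n"
    "//\n"
)
msa = parse_stockholm(sto)
# Columns where the query has a gap are dropped from the alignment and
# counted in the deletion matrix instead:
# msa.sequences       == ["MKV", "MKV"]
# msa.deletion_matrix == [[0, 0, 0], [0, 1, 0]]
# msa.descriptions    == ["query", "hit1"]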
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.
Args:
......@@ -171,7 +173,7 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
"""
sequences, _ = parse_fasta(a3m_string)
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
......@@ -187,8 +189,12 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
return Msa(
sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions
)
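The same idea for a3m input (toy alignment; lowercase letters mark insertions):

a3m = ">query\nMKV\n>hit1\nM-V\n>hit2\nMkKV\n"
msa = parse_a3m(a3m)
# Lowercase letters are stripped from the aligned sequences and tallied in
# the deletion matrix at the following aligned column:
# msa.sequences       == ["MKV", "M-V", "MKV"]
# msa.deletion_matrix == [[0, 0, 0], [0, 0, 0], [0, 1, 0]]
# msa.descriptions    == ["query", "hit1", "hit2"]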
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str
......
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
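In effect, every sibling module becomes an attribute of the package; a hedged usage note:

# Hypothetical usage, assuming this is fastfold/data/__init__.py:
# import fastfold.data
# fastfold.data.parsers   # available without an explicit submodule import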
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.np import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
"chains because these cannot be written to PDB format"
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain is
parsed. Else, all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[
residue_constants.atom_order[atom.name]
] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
)
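A minimal usage sketch (hypothetical file; single-model PDB assumed):

with open("example.pdb") as f:
    prot = from_pdb_string(f.read(), chain_id="A")
# prot.atom_positions: [num_res, 37, 3]
# prot.atom_mask:      1.0 where the PDB reports the atom, else 0.0
# prot.chain_index:    all zeros here, since only chain A was parsed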
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r'(\[[A-Z]+\]\n)'
tags = [
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C']
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if("[PRIMARY]" == g[0]):
seq = list(g[1][0].strip())  # work on a list: Python strings are immutable
for i in range(len(seq)):
if(seq[i] not in residue_constants.restypes):
seq[i] = 'X'
aatype = np.array([
residue_constants.restype_order.get(
res_symbol, residue_constants.restype_num
) for res_symbol in seq
])
elif("[TERTIARY]" == g[0]):
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros(
(len(tertiary[0])//3, residue_constants.atom_type_num, 3)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = (
np.transpose(tertiary_np[:, i::3])
)
atom_positions *= PICO_TO_ANGSTROM
elif("[MASK]" == g[0]):
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(len(mask), residue_constants.atom_type_num,)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
return Protein(
atom_positions=atom_positions,
atom_mask=atom_mask,
aatype=aatype,
residue_index=np.arange(len(aatype)),
# chain_index has no default in this dataclass, so supply a single chain
chain_index=np.zeros(len(aatype), dtype=np.int64),
b_factors=None,
)
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
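For reference, the record this produces with column-exact spacing (PDB TER columns: serial 7-11, resName 18-20, chainID 22, resSeq 23-26):

# _chain_end(100, "ALA", "A", 42)
# -> "TER     100      ALA A  42"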
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append("MODEL 1")
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1]
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0]  # Protein supports only C, N, O, S, this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1]
)
)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
# Pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
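A round-trip sketch (hypothetical paths; `prot` as produced by from_pdb_string or from_prediction):

pdb_str = to_pdb(prot)
with open("out.pdb", "w") as f:   # hypothetical output path
    f.write(pdb_str)
# Lines are padded to 80 columns; multichain proteins get TER records
# between chains, and the file ends with ENDMDL/END.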
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
Returns:
A protein instance.
"""
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
)
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
chain_index=chain_index,
b_factors=b_factors,
)
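A hedged sketch of assembling and writing a prediction (variable names are placeholders, not from this diff):

prot = from_prediction(
    features=processed_feature_dict,   # needs "aatype" and "residue_index"
    result=model_out,                  # needs "final_atom_positions"/"final_atom_mask"
    b_factors=plddt_b_factors,         # optional; zeros if omitted
)
with open("prediction.pdb", "w") as f:
    f.write(to_pdb(prot))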
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
residue identifiers. This will fail for some files in wider PDB that have
invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info["missing_residues"] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
alterations_info["missing_terminals"] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(
fixer.topology, fixer.positions, out_handle, keepIds=True
)
return out_handle.getvalue()
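A usage sketch (hypothetical input path):

alterations = {}
with open("input.pdb") as f:
    fixed_pdb_str = fix_pdb(f, alterations)
# `alterations` now records e.g. "nonstandard_residues", "missing_residues",
# "missing_heavy_atoms", "missing_terminals", and "removed_heterogens".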
def clean_structure(pdb_structure, alterations_info):
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info["removed_heterogens"] = initial_resnames.difference(
final_resnames
)
def _replace_met_se(pdb_structure, alterations_info):
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == "MET":
s_atom = res.get_atom("SD")
if s_atom.element_symbol == "Se":
s_atom.element_symbol = "S"
s_atom.element = element.get_by_symbol("S")
modified_met_residues.append(s_atom.residue_number)
alterations_info["Se_in_MET"] = modified_met_residues
def _remove_chains_of_length_one(pdb_structure, alterations_info):
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [
c.chain_id for c in model.iter_chains() if len(c) <= 1
]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info["removed_chains"] = removed_chains
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from openfold.np import protein
from openfold.np.relax import amber_minimize, utils
import numpy as np
class AmberRelaxation(object):
"""Amber relaxation."""
def __init__(
self,
*,
max_iterations: int,
tolerance: float,
stiffness: float,
exclude_residues: Sequence[int],
max_outer_iterations: int,
use_gpu: bool,
):
"""Initialize Amber Relaxer.
Args:
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
exclude_residues: Residues to exclude from per-atom restraining.
Zero-indexed.
max_outer_iterations: Maximum number of violation-informed relax
iterations. A value of 1 will run the non-iterative procedure used in
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations.
use_gpu: Whether to run on GPU
"""
self._max_iterations = max_iterations
self._tolerance = tolerance
self._stiffness = stiffness
self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations
self._use_gpu = use_gpu
def process(
self, *, prot: protein.Protein
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot,
max_iterations=self._max_iterations,
tolerance=self._tolerance,
stiffness=self._stiffness,
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
)
min_pos = out["pos"]
start_pos = out["posinit"]
rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
debug_data = {
"initial_energy": out["einit"],
"final_energy": out["efinal"],
"attempts": out["min_attempts"],
"rmsd": rmsd,
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
)
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
return min_pdb, debug_data, violations
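A hypothetical instantiation; the numeric settings mirror the docstring's guidance (CASP14-style values assumed, not taken from this diff):

relaxer = AmberRelaxation(
    max_iterations=0,        # 0 means no cap on L-BFGS steps
    tolerance=2.39,          # kcal/mol
    stiffness=10.0,          # kcal/mol/A^2
    exclude_residues=[],
    max_outer_iterations=20, # per docstring, 20 relaxes >95% of bad cases
    use_gpu=False,
)
min_pdb, debug_data, violations = relaxer.process(prot=unrelaxed_protein)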
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for minimization."""
import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
Args:
pdb_str: An input PDB string.
bfactors: A numpy array with shape [n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
[i, :].
Returns:
A new PDB string with the B-factors replaced.
"""
if bfactors.shape[-1] != residue_constants.atom_type_num:
raise ValueError(
f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
)
parser = PDB.PDBParser(QUIET=True)
handle = io.StringIO(pdb_str)
structure = parser.get_structure("", handle)
curr_resid = ("", "", "")
idx = -1
for atom in structure.get_atoms():
atom_resid = atom.parent.get_id()
if atom_resid != curr_resid:
idx += 1
if idx >= bfactors.shape[0]:
raise ValueError(
"Index into bfactors exceeds number of residues. "
"B-factors shape: {shape}, idx: {idx}."
)
curr_resid = atom_resid
atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]
new_pdb = io.StringIO()
pdb_io = PDB.PDBIO()
pdb_io.set_structure(structure)
pdb_io.save(new_pdb)
return new_pdb.getvalue()
def assert_equal_nonterminal_atom_types(
atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"]
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)  # np.bool was removed from NumPy; use the builtin
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
)
......@@ -18,7 +18,12 @@ from typing import Dict, Text, Tuple
import torch
from fastfold.np import residue_constants as rc
from fastfold.common import residue_constants as rc
from fastfold.utils import geometry, tensor_utils
import numpy as np
......
......@@ -142,7 +142,6 @@ def main(args):
def inference_multimer_model(args):
print("running in multimer mode...")
config = model_config(args.model_name)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
predict_max_templates = 4
......@@ -235,11 +234,55 @@ def inference_multimer_model(args):
feature_dict = data_processor.process_fasta(
fasta_path=fasta_path, alignment_dir=local_alignment_dir
)
# feature_dict = pickle.load(open("/home/lcmql/data/features_pdb1o5d.pkl", "rb"))
processed_feature_dict = feature_processor.process_features(
feature_dict, mode='predict', is_multimer=True,
)
batch = processed_feature_dict
manager = mp.Manager()
result_q = manager.Queue()
torch.multiprocessing.spawn(inference_model, nprocs=args.gpus, args=(args.gpus, result_q, batch, args))
out = result_q.get()
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
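# pLDDT is predicted per residue; repeating it across all 37 atom slots
# places the score in the B-factor column of every atom written to the PDB.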
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=True,
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
def inference_monomer_model(args):
print("running in monomer mode...")
config = model_config(args.model_name)
......