Unverified commit ea7a6584, authored by Fazzie-Maqianli and committed by GitHub

Multimer (#55)

* support multimer data pipeline

* module for multimer

* support multimer inference

* add geometry

* delete unused code

* delete debug print
parent 2af21904
from typing import Sequence
import torch
from fastfold.data.data_transforms import curry1
from fastfold.utils.tensor_utils import masked_mean
def gumbel_noise(
shape: Sequence[int],
device: torch.device,
eps=1e-6,
generator=None,
) -> torch.Tensor:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise = torch.rand(
shape, dtype=torch.float32, device=device, generator=generator
)
gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
return gumbel
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
"""Samples from a probability distribution given by 'logits'.
This uses the Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.nn.functional.one_hot(
torch.argmax(logits + z, dim=-1),
logits.shape[-1],
)
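# Illustrative sketch (not from this commit; shapes and values are invented):
# the Gumbel-max trick above draws a one-hot sample whose index is distributed
# like Categorical(softmax(logits)), without explicitly normalizing the logits.
def _example_gumbel_max_sample() -> torch.Tensor:
    logits = torch.log(torch.tensor([0.1, 0.2, 0.7]))
    g = torch.Generator()
    g.manual_seed(0)
    # One-hot vector of shape [3]; index 2 is drawn with probability ~0.7.
    return gumbel_max_sample(logits, generator=g)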
def gumbel_argsort_sample_idx(
logits: torch.Tensor,
generator=None
) -> torch.Tensor:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Indices defining the sampled permutation, sorted by logits + Gumbel noise in descending order.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.argsort(logits + z, dim=-1, descending=True)
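# Illustrative sketch (not from this commit): gumbel_argsort_sample_idx returns
# a random permutation of row indices biased towards rows with higher logits;
# rows whose logits were pushed towards -inf (see sample_msa below) land last.
def _example_gumbel_argsort_sample_idx() -> torch.Tensor:
    logits = torch.tensor([0.0, 0.0, -1e6])  # third row is effectively masked out
    order = gumbel_argsort_sample_idx(logits)
    # `order` is a permutation of [0, 1, 2]; index 2 is (almost surely) last.
    return order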
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor(
[0.05] * 20 + [0., 0.],
device=batch['msa'].device
)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
categorical_probs = torch.nn.functional.pad(
categorical_probs, [0,1], value=mask_prob
)
sh = batch['msa'].shape
mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
mask_position *= batch['msa_mask'].to(mask_position.dtype)
logits = torch.log(categorical_probs + eps)
g = torch.Generator(device=batch["msa"].device)
if seed is not None:
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
bert_msa = torch.where(
mask_position,
torch.argmax(bert_msa, dim=-1),
batch['msa']
)
bert_msa *= batch['msa_mask'].to(bert_msa.dtype)
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.to(torch.float32)
else:
batch['bert_mask'] = mask_position.to(torch.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
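# Worked numbers (assuming the published AlphaFold defaults, not read from this
# diff): with uniform_prob = profile_prob = same_prob = 0.1, the remaining
# mask_prob is 1 - 0.3 = 0.7, so roughly 70% of the positions selected by
# `replace_fraction` are replaced with the extra [MASK] column appended by the
# pad above, and the rest are resampled from the uniform/profile/identity mix.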
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device = batch["msa_mask"].device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
weights = torch.tensor(
[1.] * 21 + [gap_agreement_weight] + [0.],
device=device,
)
msa_mask = batch['msa_mask']
msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = torch.einsum(
'mrc, nrc->nm',
extra_one_hot_masked,
weights * msa_one_hot_masked
)
cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = torch.sum(cluster_assignment, dim=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = torch.einsum(
'nm, mc->nc',
cluster_assignment,
extra_mask * extra_deletion_matrix
)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
batch['cluster_profile'] = cluster_profile
batch['cluster_deletion_mean'] = cluster_deletion_mean
return batch
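# Note (descriptive, inferred from the code above): the softmax over a
# 1e3-scaled agreement score is effectively a hard argmax, so each extra MSA
# sequence contributes almost entirely to its single best-matching cluster when
# the cluster profiles and deletion means are averaged.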
def create_target_feat(batch):
"""Create the target features"""
batch["target_feat"] = torch.nn.functional.one_hot(
batch["aatype"], 21
).to(torch.float32)
return batch
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
device = batch["msa"]
msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
deletion_mean_value = (
torch.atan(
batch['cluster_deletion_mean'] / 3.) *
(2. / pi)
)[..., None]
msa_feat = torch.cat(
[
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
],
dim=-1,
)
batch["msa_feat"] = msa_feat
return batch
def build_extra_msa_feat(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa']
deletion_matrix = batch['extra_deletion_matrix']
msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (
(torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
)
extra_msa_mask = batch['extra_msa_mask']
catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
return catted
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample into the MSA stack.
max_extra_msa_seq: maximum number of remaining sequences kept as `extra_*` features.
seed: optional seed for the sampling generator.
Returns:
Protein with sampled msa.
"""
g = torch.Generator(device=batch["msa"].device)
if seed is not None:
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
logits = (torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.) * inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = torch.nn.functional.pad(
batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
(1, 0),
value=1.
)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * inf
index_order = gumbel_argsort_sample_idx(logits, generator=g)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:][:max_extra_msa_seq]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch["msa_profile"] = masked_mean(
batch['msa_mask'][..., None],
torch.nn.functional.one_hot(batch['msa'], 22),
dim=-3,
)
return batch
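# Minimal sketch (tensor sizes invented for illustration): make_msa_profile
# reduces the one-hot MSA over the sequence dimension, yielding a per-residue
# distribution over the 22 MSA symbols to which masked positions do not contribute.
def _example_make_msa_profile() -> torch.Tensor:
    batch = {
        "msa": torch.zeros((4, 10), dtype=torch.long),  # 4 sequences, 10 residues
        "msa_mask": torch.ones((4, 10)),
    }
    batch = make_msa_profile(batch)
    return batch["msa_profile"]  # [10, 22]; each row sums to ~1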
@@ -20,7 +20,7 @@ import ml_collections
import numpy as np
import torch
from fastfold.data import input_pipeline
from fastfold.data import input_pipeline, input_pipeline_multimer
FeatureDict = Mapping[str, np.ndarray]
@@ -72,10 +72,14 @@ def make_data_config(
def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
is_multimer: bool,
mode: str,
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
if is_multimer:
num_res = int(np_example["seq_length"])
else:
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:
@@ -86,12 +90,20 @@ def np_example_to_features(
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
if is_multimer:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
return {k: v for k, v in features.items()}
@@ -107,9 +119,11 @@ class FeaturePipeline:
self,
raw_features: FeatureDict,
mode: str = "train",
is_multimer: bool = False,
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
is_multimer=is_multimer,
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from fastfold.data import (
data_transforms,
data_transforms_multimer,
)
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
]
if(common_cfg.use_templates):
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
])
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms_multimer.sample_msa(
max_msa_clusters,
max_extra_msa,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms_multimer.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed is not None else None,
)
)
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
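# Minimal sketch (shapes invented): map_fn evaluates `fun` once per recycling
# iteration and stacks every returned feature along a trailing dimension, so a
# [N_res, C] feature becomes [N_res, C, no_recycles + 1].
def _example_map_fn() -> torch.Size:
    def fake_iteration(i):
        return {"feat": torch.ones(5, 3) * i}
    out = map_fn(fake_iteration, torch.arange(4))
    return out["feat"].shape  # torch.Size([5, 3, 4])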
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""
import dataclasses
import re
from typing import Optional
# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE)
@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with a uniprot_accession_id and species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
species_id=matches.group('SpeciesIdentifier')
)
return Identifiers()
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
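# Illustrative usage (not part of this commit): the UniProtKB mnemonic species
# code is all that survives parsing, and it is what drives cross-chain pairing.
def _example_get_identifiers() -> Identifiers:
    ids = get_identifiers("tr|A0A146SKV9|A0A146SKV9_FUNHE/1-100")
    # ids.species_id == "FUNHE"; descriptions that do not match the UniProt
    # pattern yield an empty Identifiers().
    return ids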
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data pipeline."""
import collections
import functools
import string
from typing import Any, Dict, Iterable, List, Sequence, Mapping
import numpy as np
import pandas as pd
import scipy.linalg
from openfold.np import residue_constants
# TODO: This stuff should probably also be in a config
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0}
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
'all_atom_mask', 'seq_mask', 'between_segment_residues',
'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
'sym_id', 'entity_mask', 'deletion_mean',
'prediction_atom_mask',
'literature_positions', 'atom_indices_to_group_indices',
'rigid_group_default_frame')
TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
'template_all_atom_mask')
CHAIN_FEATURES = ('num_alignments', 'seq_length')
def create_paired_features(
chains: Iterable[Mapping[str, np.ndarray]],
) -> List[Mapping[str, np.ndarray]]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains
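# Note (descriptive, inferred from the code above): `paired_rows` has shape
# [num_paired_rows, num_chains]; a -1 entry selects the padding row that
# pad_features() appends below, i.e. an all-gap row for chains that have no
# sequence from a given species.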
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row is selected as the 'paired' row for chains that have no
paired alignment for a given species (i.e. in the case of partial pairing).
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in ('msa_all_seq', 'msa_mask_all_seq',
'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[species] = species_df
return species_lookup
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [len(species_df) for species_df in this_species_msa_dfs
if species_df is not None]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = (
lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
def pair_sequences(
examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
if np.any(
np.array([len(species_df) for species_df in
this_species_msa_dfs if
isinstance(species_df, pd.DataFrame)]) > 600):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows) for
num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e., all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
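# Small sketch (values invented for illustration): block_diag places per-chain
# MSA blocks on the diagonal and fills everything off-diagonal with pad_value,
# e.g. the gap index when block-diagonalising unpaired MSAs.
def _example_block_diag() -> np.ndarray:
    a = np.array([[1, 1]])
    b = np.array([[2], [2]])
    return block_diag(a, b, pad_value=9)
    # -> [[1, 1, 9],
    #     [9, 9, 2],
    #     [9, 9, 2]]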
def _correct_post_merged_feats(
np_example: Mapping[str, np.ndarray],
np_chains_list: Sequence[Mapping[str, np.ndarray]],
pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging."""
num_res = np_example['aatype'].shape[0]
np_example['seq_length'] = np.asarray(
[num_res] * num_res,
dtype=np.int32
)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0],
dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list
]
np_example['bert_mask'] = block_diag(
*msa_masks, pad_value=0
)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_mask_block_diag = block_diag(
*msa_masks, pad_value=0
)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag],
axis=0
)
return np_example
def _pad_templates(chains: Sequence[Mapping[str, np.ndarray]],
max_templates: int) -> Sequence[Mapping[str, np.ndarray]]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[Mapping[str, np.ndarray]],
pair_msa_sequences: bool) -> Mapping[str, np.ndarray]:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
def _merge_homomers_dense_msa(
chains: Iterable[Mapping[str, np.ndarray]]) -> Sequence[Mapping[str, np.ndarray]]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains]
return chains
def _concatenate_paired_and_unpaired_features(
example: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(example['msa'].shape[0],
dtype=np.int32)
return example
def merge_chain_features(np_chains_list: List[Mapping[str, np.ndarray]],
pair_msa_sequences: bool,
max_templates: int) -> Mapping[str, np.ndarray]:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences)
return np_example
def deduplicate_unpaired_sequences(
np_chains: List[Mapping[str, np.ndarray]]) -> List[Mapping[str, np.ndarray]]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
for chain in np_chains:
# Convert the msa_all_seq numpy array to a tuple for hashing.
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if tuple(seq) not in sequence_set:
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
chain[feature_name] = chain[feature_name][keep_rows]
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
@@ -26,12 +26,12 @@ from fastfold.utils.feats import (
)
from fastfold.model.nn.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
)
from fastfold.model.nn.embedders_multimer import TemplateEmbedderMultimer, InputEmbedderMultimer
from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
from fastfold.model.nn.heads import AuxiliaryHeads
import fastfold.common.residue_constants as residue_constants
@@ -74,25 +74,31 @@ class AlphaFold(nn.Module):
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
)
self.template_embedder = TemplateEmbedderMultimer(
template_config,
)
else:
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.template_angle_embedder = TemplateAngleEmbedder(
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
@@ -103,6 +109,7 @@ class AlphaFold(nn.Module):
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
is_multimer=self.globals.is_multimer,
**config["structure_module"],
)
......
@@ -125,146 +125,6 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
......
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Dict
from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import dgram_from_positions
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if self.use_chain_relative:
self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch: Dict[str, torch.Tensor]):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = asym_id[..., None] == asym_id[..., None, :]
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if self.use_chain_relative:
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = entity_id[..., None] == entity_id[..., None, :]
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain.long(),
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset,
2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)
return self.linear_relpos(rel_feat)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
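# Worked example (assuming the published AlphaFold-Multimer defaults rather
# than anything read from this diff): with max_relative_idx=32,
# use_chain_relative=True and max_relative_chain=2, no_bins is
# (2*32 + 2) + 1 + (2*2 + 2) = 73, i.e. 66 clipped residue-offset bins plus a
# same-entity flag plus 6 clipped chain-offset bins, which is the input width
# of linear_relpos.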
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super().__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super().__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
from fastfold.common.residue_constants import (
@@ -25,6 +25,9 @@ from fastfold.common.residue_constants import (
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from fastfold.utils.geometry.quat_rigid import QuatRigid
from fastfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from fastfold.utils.geometry.vector import Vec3Array
from fastfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
@@ -150,11 +153,47 @@ class AngleResnet(nn.Module):
return unnormalized_s, s
class PointProjection(nn.Module):
def __init__(
self,
c_hidden: int,
num_points: int,
no_heads: int,
return_local_points: bool = False,
):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
def forward(
self,
activations: torch.Tensor,
rigids: Rigid3Array,
) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
points_local = points_local.reshape(
*points_local.shape[:-1],
self.no_heads,
-1,
)
points_local = torch.split(points_local, points_local.shape[-1] // 3, dim=-1)
points_local = Vec3Array(*points_local)
points_global = rigids[..., None, None].apply_to_point(points_local)
if self.return_local_points:
return points_global, points_local
return points_global
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(
self,
c_s: int,
@@ -165,6 +204,7 @@ class InvariantPointAttention(nn.Module):
no_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
is_multimer: bool = False,
):
"""
Args:
@@ -191,23 +231,45 @@ class InvariantPointAttention(nn.Module):
self.no_v_points = no_v_points
self.inf = inf
self.eps = eps
self.is_multimer = is_multimer
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc)
self.linear_kv = Linear(self.c_s, 2 * hc)
if not self.is_multimer:
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))
self.linear_kv = Linear(self.c_s, 2 * hc)
hpq = self.no_heads * self.no_qk_points * 3
self.linear_q_points = Linear(self.c_s, hpq)
hpq = self.no_heads * self.no_qk_points * 3
self.linear_q_points = Linear(self.c_s, hpq)
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
self.linear_kv_points = Linear(self.c_s, hpkv)
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
self.linear_kv_points = Linear(self.c_s, hpkv)
# hpv = self.no_heads * self.no_v_points * 3
hpv = self.no_heads * self.no_v_points * 3
else:
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))
self.linear_q_points = PointProjection(
self.c_s, self.no_qk_points, self.no_heads
)
self.linear_k = Linear(self.c_s, hc, bias=False)
self.linear_v = Linear(self.c_s, hc, bias=False)
self.linear_k_points = PointProjection(
self.c_s,
self.no_qk_points,
self.no_heads,
)
self.linear_v_points = PointProjection(
self.c_s,
self.no_v_points,
self.no_heads,
)
self.linear_b = Linear(self.c_z, self.no_heads)
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
@@ -225,7 +287,7 @@ class InvariantPointAttention(nn.Module):
self,
s: torch.Tensor,
z: torch.Tensor,
r: Rigid,
r: Union[Rigid, Rigid3Array],
mask: torch.Tensor,
) -> torch.Tensor:
"""
@@ -244,48 +306,72 @@ class InvariantPointAttention(nn.Module):
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# The following two blocks are equivalent
# They're separated only to preserve compatibility with old AF weights
if self.is_multimer:
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# [*, N_res, H, P_qk]
q_pts = self.linear_q_points(s, r)
# [*, N_res, H * C_hidden]
k = self.linear_k(s)
v = self.linear_v(s)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# [*, N_res, H, C_hidden]
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_qk, 3]
k_pts = self.linear_k_points(s, r)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H, P_v, 3]
v_pts = self.linear_v_points(s, r)
else:
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3))
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
)
##########################
# Compute attention scores
@@ -299,14 +385,20 @@ class InvariantPointAttention(nn.Module):
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
if self.is_multimer:
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
# [*, N_res, N_res, H, P_q]
pt_att = sum([c**2 for c in pt_att])
else:
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att**2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
@@ -323,7 +415,7 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
@@ -331,35 +423,47 @@ class InvariantPointAttention(nn.Module):
# Compute output
################
# [*, N_res, H, C_hidden]
o = torch.matmul(
a, v.transpose(-2, -3).to(dtype=a.dtype)
).transpose(-2, -3)
o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
if self.is_multimer:
# [*, N_res, H, P_v]
o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# [*, N_res, H, P_v]
o_pt = r[..., None, None].apply_inverse_to_point(o_pt)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(self.eps)
else:
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
@@ -368,11 +472,16 @@ class InvariantPointAttention(nn.Module):
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
if self.is_multimer:
s = self.linear_out(
torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)
)
else:
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
return s
@@ -476,6 +585,7 @@ class StructureModule(nn.Module):
trans_scale_factor,
epsilon,
inf,
is_multimer=False,
**kwargs,
):
"""
@@ -529,6 +639,7 @@ class StructureModule(nn.Module):
self.trans_scale_factor = trans_scale_factor
self.epsilon = epsilon
self.inf = inf
self.is_multimer = is_multimer
# To be lazily initialized later
self.default_frames = None
@@ -550,6 +661,7 @@ class StructureModule(nn.Module):
self.no_v_points,
inf=self.inf,
eps=self.epsilon,
is_multimer=self.is_multimer,
)
self.ipa_dropout = nn.Dropout(self.dropout_rate)
......
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
import re
from fastfold.np import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
PICO_TO_ANGSTROM = 0.01
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)
assert(PDB_MAX_CHAINS == 62)
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this
# residue belongs to
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
"chains because these cannot be written to PDB format"
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain is
parsed. Else, all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if(chain_id is not None and chain.id != chain_id):
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[
residue_constants.atom_order[atom.name]
] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
)
def from_proteinnet_string(proteinnet_str: str) -> Protein:
tag_re = r'(\[[A-Z]+\]\n)'
tags = [
tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
]
groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])
atoms = ['N', 'CA', 'C']
aatype = None
atom_positions = None
atom_mask = None
for g in groups:
if("[PRIMARY]" == g[0]):
            seq = list(g[1][0].strip())
            for i in range(len(seq)):
                if(seq[i] not in residue_constants.restypes):
                    seq[i] = 'X'
aatype = np.array([
residue_constants.restype_order.get(
res_symbol, residue_constants.restype_num
) for res_symbol in seq
])
elif("[TERTIARY]" == g[0]):
tertiary = []
for axis in range(3):
tertiary.append(list(map(float, g[1][axis].split())))
tertiary_np = np.array(tertiary)
atom_positions = np.zeros(
(len(tertiary[0])//3, residue_constants.atom_type_num, 3)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_positions[:, residue_constants.atom_order[atom], :] = (
np.transpose(tertiary_np[:, i::3])
)
atom_positions *= PICO_TO_ANGSTROM
elif("[MASK]" == g[0]):
mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
atom_mask = np.zeros(
(len(mask), residue_constants.atom_type_num,)
).astype(np.float32)
for i, atom in enumerate(atoms):
atom_mask[:, residue_constants.atom_order[atom]] = 1
atom_mask *= mask[..., None]
    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        chain_index=np.zeros(len(aatype), dtype=np.int32),
        b_factors=None,
    )
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
f'{chain_name:>1}{residue_index:>4}'
)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append("MODEL 1")
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1]
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1]
)
)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
# Pad all lines to 80 characters
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None,
remove_leading_feature_dimension: bool = True,
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
remove_leading_feature_dimension: Whether to remove the leading dimension
of the `features` values
Returns:
A protein instance.
"""
def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
return arr[0] if remove_leading_feature_dimension else arr
if 'asym_id' in features:
chain_index = _maybe_remove_leading_dim(features["asym_id"])
else:
chain_index = np.zeros_like(
_maybe_remove_leading_dim(features["aatype"])
)
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=_maybe_remove_leading_dim(features["aatype"]),
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1,
chain_index=chain_index,
b_factors=b_factors,
)
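For orientation, a usage sketch of the protein module above; it is illustrative only and assumes the file is importable as fastfold.np.protein and that example.pdb exists locally:
from fastfold.np import protein
with open("example.pdb") as f:
    prot = protein.from_pdb_string(f.read(), chain_id=None)  # parse every chain
print(prot.aatype.shape)          # [num_res]
print(prot.atom_positions.shape)  # [num_res, 37, 3]
# Round-trip back to a minimal, 80-column-padded PDB string.
with open("roundtrip.pdb", "w") as f:
    f.write(protein.to_pdb(prot))
# Assembling a Protein from model outputs instead: features must carry
# "aatype" and "residue_index" (plus "asym_id" for multimer inputs), and
# result must carry "final_atom_positions" and "final_atom_mask".
# pred = protein.from_prediction(features=model_features, result=model_result)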
import os
import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
fix_pdb uses a third-party tool. We also support fixing some additional edge
cases like removing chains of length one (see clean_structure).
"""
import io
import pdbfixer
from simtk.openmm import app
from simtk.openmm.app import element
def fix_pdb(pdbfile, alterations_info):
"""Apply pdbfixer to the contents of a PDB file; return a PDB string result.
1) Replaces nonstandard residues.
2) Removes heterogens (non protein residues) including water.
3) Adds missing residues and missing atoms within existing residues.
4) Adds hydrogens assuming pH=7.0.
5) KeepIds is currently true, so the fixer must keep the existing chain and
residue identifiers. This will fail for some files in the wider PDB that
have invalid IDs.
Args:
pdbfile: Input PDB file handle.
alterations_info: A dict that will store details of changes made.
Returns:
A PDB string representing the fixed structure.
"""
fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
fixer.findNonstandardResidues()
alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
fixer.replaceNonstandardResidues()
_remove_heterogens(fixer, alterations_info, keep_water=False)
fixer.findMissingResidues()
alterations_info["missing_residues"] = fixer.missingResidues
fixer.findMissingAtoms()
alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
alterations_info["missing_terminals"] = fixer.missingTerminals
fixer.addMissingAtoms(seed=0)
fixer.addMissingHydrogens()
out_handle = io.StringIO()
app.PDBFile.writeFile(
fixer.topology, fixer.positions, out_handle, keepIds=True
)
return out_handle.getvalue()
def clean_structure(pdb_structure, alterations_info):
"""Applies additional fixes to an OpenMM structure, to handle edge cases.
Args:
pdb_structure: An OpenMM structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
_replace_met_se(pdb_structure, alterations_info)
_remove_chains_of_length_one(pdb_structure, alterations_info)
def _remove_heterogens(fixer, alterations_info, keep_water):
"""Removes the residues that Pdbfixer considers to be heterogens.
Args:
fixer: A Pdbfixer instance.
alterations_info: A dict that will store details of changes made.
keep_water: If True, water (HOH) is not considered to be a heterogen.
"""
initial_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
initial_resnames.add(residue.name)
fixer.removeHeterogens(keepWater=keep_water)
final_resnames = set()
for chain in fixer.topology.chains():
for residue in chain.residues():
final_resnames.add(residue.name)
alterations_info["removed_heterogens"] = initial_resnames.difference(
final_resnames
)
def _replace_met_se(pdb_structure, alterations_info):
"""Replace the Se in any MET residues that were not marked as modified."""
modified_met_residues = []
for res in pdb_structure.iter_residues():
name = res.get_name_with_spaces().strip()
if name == "MET":
s_atom = res.get_atom("SD")
if s_atom.element_symbol == "Se":
s_atom.element_symbol = "S"
s_atom.element = element.get_by_symbol("S")
modified_met_residues.append(s_atom.residue_number)
alterations_info["Se_in_MET"] = modified_met_residues
def _remove_chains_of_length_one(pdb_structure, alterations_info):
"""Removes chains that correspond to a single amino acid.
A single amino acid in a chain is both N and C terminus. There is no force
template for this case.
Args:
pdb_structure: An OpenMM pdb_structure to modify and fix.
alterations_info: A dict that will store details of changes made.
"""
removed_chains = {}
for model in pdb_structure.iter_models():
valid_chains = [c for c in model.iter_chains() if len(c) > 1]
invalid_chain_ids = [
c.chain_id for c in model.iter_chains() if len(c) <= 1
]
model.chains = valid_chains
for chain_id in invalid_chain_ids:
model.chains_by_id.pop(chain_id)
removed_chains[model.number] = invalid_chain_ids
alterations_info["removed_chains"] = removed_chains
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Amber relaxation."""
from typing import Any, Dict, Sequence, Tuple
from fastfold.np import protein
from fastfold.np.relax import amber_minimize, utils
import numpy as np
class AmberRelaxation(object):
"""Amber relaxation."""
def __init__(
self,
*,
max_iterations: int,
tolerance: float,
stiffness: float,
exclude_residues: Sequence[int],
max_outer_iterations: int,
use_gpu: bool,
):
"""Initialize Amber Relaxer.
Args:
max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
tolerance: kcal/mol, the energy tolerance of L-BFGS.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
exclude_residues: Residues to exclude from per-atom restraining.
Zero-indexed.
max_outer_iterations: Maximum number of violation-informed relax
iterations. A value of 1 will run the non-iterative procedure used in
CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
as soon as there are no violations, hence in most cases this causes no
slowdown. In the worst case we do 20 outer iterations.
use_gpu: Whether to run on GPU
"""
self._max_iterations = max_iterations
self._tolerance = tolerance
self._stiffness = stiffness
self._exclude_residues = exclude_residues
self._max_outer_iterations = max_outer_iterations
self._use_gpu = use_gpu
def process(
self, *, prot: protein.Protein
) -> Tuple[str, Dict[str, Any], np.ndarray]:
"""Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
out = amber_minimize.run_pipeline(
prot=prot,
max_iterations=self._max_iterations,
tolerance=self._tolerance,
stiffness=self._stiffness,
exclude_residues=self._exclude_residues,
max_outer_iterations=self._max_outer_iterations,
use_gpu=self._use_gpu,
)
min_pos = out["pos"]
start_pos = out["posinit"]
rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
debug_data = {
"initial_energy": out["einit"],
"final_energy": out["efinal"],
"attempts": out["min_attempts"],
"rmsd": rmsd,
}
pdb_str = amber_minimize.clean_protein(prot)
min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
utils.assert_equal_nonterminal_atom_types(
protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
)
violations = out["structural_violations"][
"total_per_residue_violations_mask"
]
return min_pdb, debug_data, violations
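A usage sketch for AmberRelaxation; the numeric values are illustrative choices, not settings prescribed by this PR, and unrelaxed_prot stands in for a real protein.Protein produced upstream:
from fastfold.np.relax.relax import AmberRelaxation  # assumed import path
relaxer = AmberRelaxation(
    max_iterations=0,         # 0 means no cap on L-BFGS steps
    tolerance=2.39,           # kcal/mol
    stiffness=10.0,           # kcal/mol/A^2 heavy-atom restraint
    exclude_residues=[],
    max_outer_iterations=20,  # iterate until no violations remain
    use_gpu=False,
)
# unrelaxed_prot: a protein.Protein, e.g. from protein.from_prediction(...)
relaxed_pdb_str, debug_data, violations = relaxer.process(prot=unrelaxed_prot)
print(debug_data["rmsd"], violations.sum())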
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for minimization."""
import io
from fastfold.np import residue_constants
from Bio import PDB
import numpy as np
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
pdb_file = io.StringIO(pdb_str)
structure = PdbStructure(pdb_file)
topology = openmm_app.PDBFile(structure).getTopology()
with io.StringIO() as f:
openmm_app.PDBFile.writeFile(topology, pos, f)
return f.getvalue()
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
"""Overwrites the B-factors in pdb_str with contents of bfactors array.
Args:
pdb_str: An input PDB string.
bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the
B-factors are per residue; i.e. that the nonzero entries are identical in
[0, i, :].
Returns:
A new PDB string with the B-factors replaced.
"""
if bfactors.shape[-1] != residue_constants.atom_type_num:
raise ValueError(
f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
)
parser = PDB.PDBParser(QUIET=True)
handle = io.StringIO(pdb_str)
structure = parser.get_structure("", handle)
curr_resid = ("", "", "")
idx = -1
for atom in structure.get_atoms():
atom_resid = atom.parent.get_id()
if atom_resid != curr_resid:
idx += 1
if idx >= bfactors.shape[0]:
            raise ValueError(
                "Index into bfactors exceeds number of residues. "
                f"B-factors shape: {bfactors.shape}, idx: {idx}."
            )
curr_resid = atom_resid
atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]
new_pdb = io.StringIO()
pdb_io = PDB.PDBIO()
pdb_io.set_structure(structure)
pdb_io.save(new_pdb)
return new_pdb.getvalue()
def assert_equal_nonterminal_atom_types(
atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
"""Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"]
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
)
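The most common use of overwrite_b_factors is stamping a per-residue confidence score (e.g. pLDDT) into the B-factor column; a hedged sketch, with model.pdb as a placeholder input:
import numpy as np
from fastfold.np import protein, residue_constants
from fastfold.np.relax import utils  # assumed import path
with open("model.pdb") as f:          # placeholder structure
    pdb_str = f.read()
prot = protein.from_pdb_string(pdb_str)
plddt = np.random.uniform(50, 95, size=prot.aatype.shape[0])  # stand-in scores
# Shape [n_res, 37]; the function reads the CA column for each residue.
bfactors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
pdb_with_plddt = utils.overwrite_b_factors(pdb_str, bfactors)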
......@@ -30,6 +30,22 @@ from fastfold.utils.tensor_utils import (
tensor_tree_map,
)
def dgram_from_positions(
pos: torch.Tensor,
min_bin: float = 3.25,
max_bin: float = 50.75,
    no_bins: int = 39,
inf: float = 1e8,
):
dgram = torch.sum(
(pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
)
lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
return dgram
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order["G"]
......
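The newly added dgram_from_positions one-hot encodes pairwise squared distances into no_bins distance bins; a small shape-level sanity check, with random placeholder coordinates:
import torch
pos = torch.randn(1, 16, 3)  # [*, N_res, 3] pseudo-beta coordinates (placeholder)
dgram = dgram_from_positions(pos, min_bin=3.25, max_bin=50.75, no_bins=39)
print(dgram.shape)                 # torch.Size([1, 16, 16, 39])
print(dgram.sum(dim=-1).unique())  # 0. or 1.: each pair falls in at most one bin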
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from fastfold.utils.geometry import rigid_matrix_vector
from fastfold.utils.geometry import rotation_matrix
from fastfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
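These re-exports back the multimer code paths above (for instance, the Vec3Array-valued o_pt in InvariantPointAttention). A hedged sketch of the API this assumes, mirroring DeepMind's geometry objects:
import torch
from fastfold.utils.geometry import Vec3Array, euclidean_distance
a = Vec3Array(torch.randn(8), torch.randn(8), torch.randn(8))
b = Vec3Array(torch.randn(8), torch.randn(8), torch.randn(8))
d = euclidean_distance(a, b, epsilon=1e-6)  # [8] elementwise distances
n = a.norm(1e-6)                            # [8] eps-guarded norms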