Commit a1c29028 authored by zhangqha

update uni-fold
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data """
import collections
from typing import Dict, Iterable, List, Sequence
from .residue_constants import restypes_with_x_and_gap
from .data_ops import NumpyDict
import numpy as np
import pandas as pd
import scipy.linalg
MSA_GAP_IDX = restypes_with_x_and_gap.index("-")
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
MSA_PAD_VALUES = {
"msa_all_seq": MSA_GAP_IDX,
"msa_mask_all_seq": 1,
"deletion_matrix_all_seq": 0,
"deletion_matrix_int_all_seq": 0,
"msa": MSA_GAP_IDX,
"msa_mask": 1,
"deletion_matrix": 0,
"deletion_matrix_int": 0,
}
MSA_FEATURES = ("msa", "msa_mask", "deletion_matrix", "deletion_matrix_int")
SEQ_FEATURES = (
"residue_index",
"aatype",
"all_atom_positions",
"all_atom_mask",
"seq_mask",
"between_segment_residues",
"has_alt_locations",
"has_hetatoms",
"asym_id",
"entity_id",
"sym_id",
"entity_mask",
"deletion_mean",
"prediction_atom_mask",
"literature_positions",
"atom_indices_to_group_indices",
"rigid_group_default_frame",
# zy
"num_sym",
)
TEMPLATE_FEATURES = (
"template_aatype",
"template_all_atom_positions",
"template_all_atom_mask",
)
CHAIN_FEATURES = ("num_alignments", "seq_length")
def create_paired_features(
chains: Iterable[NumpyDict],
) -> List[NumpyDict]:
"""Returns the original chains with paired NUM_SEQ features.
Args:
chains: A list of feature dictionaries for each chain.
Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()
if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)
for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if "_all_seq" not in k}
for feature_name in chain_keys:
if feature_name.endswith("_all_seq"):
feats_padded = pad_features(chain[feature_name], feature_name)
new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]]
new_chain["num_alignments_all_seq"] = np.asarray(
len(paired_rows[:, chain_num])
)
updated_chains.append(new_chain)
return updated_chains
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.
The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.
Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.
Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in (
"msa_all_seq",
"msa_mask_all_seq",
"deletion_matrix_all_seq",
"deletion_matrix_int_all_seq",
):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], feature.dtype)
elif feature_name == "msa_species_identifiers_all_seq":
padding = [b""]
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded
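# Illustrative sketch of the padding above (uses numpy as imported in this module;
# MSA_GAP_IDX is 21 for the standard residue ordering): a dummy 2-row, 3-residue
# MSA gains one trailing gap row, so shapes go (2, 3) -> (3, 3).
#   padded = pad_features(np.zeros((2, 3), dtype=np.int32), "msa_all_seq")
#   # padded.shape -> (3, 3); padded[-1] is all MSA_GAP_IDX (gap tokens)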
def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features["msa_all_seq"]
query_seq = chain_msa[0]
per_seq_similarity = np.sum(query_seq[None] == chain_msa, axis=-1) / float(
len(query_seq)
)
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame(
{
"msa_species_identifiers": chain_features[
"msa_species_identifiers_all_seq"
],
"msa_row": np.arange(
len(chain_features["msa_species_identifiers_all_seq"])
),
"msa_similarity": per_seq_similarity,
"gap": per_seq_gap,
}
)
return msa_df
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby("msa_species_identifiers"):
species_lookup[species] = species_df
return species_lookup
def _match_rows_by_sequence_similarity(
this_species_msa_dfs: List[pd.DataFrame],
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.
Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.
Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []
num_seqs = [
len(species_df) for species_df in this_species_msa_dfs if species_df is not None
]
take_num_seqs = np.min(num_seqs)
sort_by_similarity = lambda x: x.sort_values(
"msa_similarity", axis=0, ascending=False
)
for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows
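# Illustrative sketch (hypothetical row indices and similarities): suppose for one
# species chain A has MSA rows [3, 7] with similarities [0.4, 0.9] and chain B has
# row [5] with similarity [0.8]. take_num_seqs is 1, chain A keeps its most similar
# row 7, and the transposed result is a single pairing [7, 5].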
def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""
num_examples = len(examples)
all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))
common_species = sorted(common_species)
common_species.remove(b"") # Remove target sequence species.
all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]
for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)
# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue
if np.any(
np.array(
[
len(species_df)
for species_df in this_species_msa_dfs
if isinstance(species_df, pd.DataFrame)
]
)
> 600
):
continue
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows)
for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.
Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.
Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e., all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = np.abs(
np.array([np.prod(rows.astype(np.float64)) for rows in paired_rows])
)
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])
return np.array(all_paired_msa_rows)
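# Illustrative sketch (hypothetical indices): pairings covering more chains come
# first; within a group, rows are sorted by the absolute product of their indices.
#   reorder_paired_rows({3: np.array([[1, 1, 1]]),
#                        2: np.array([[4, 4, -1], [1, 2, -1]])})
#   # -> array([[ 1,  1,  1],
#   #           [ 1,  2, -1],
#   #           [ 4,  4, -1]])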
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
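# Illustrative sketch: unlike scipy.linalg.block_diag, the off-diagonal blocks can
# be filled with an arbitrary pad_value instead of 0.
#   block_diag(np.ones((2, 2)), np.ones((1, 1)), pad_value=-1.0)
#   # -> array([[ 1.,  1., -1.],
#   #           [ 1.,  1., -1.],
#   #           [-1., -1.,  1.]])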
def _correct_post_merged_feats(
np_example: NumpyDict, np_chains_list: Sequence[NumpyDict], pair_msa_sequences: bool
) -> NumpyDict:
"""Adds features that need to be computed/recomputed post merging."""
np_example["seq_length"] = np.asarray(np_example["aatype"].shape[0], dtype=np.int32)
np_example["num_alignments"] = np.asarray(
np_example["msa"].shape[0], dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain["msa"].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example["cluster_bias_mask"] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x["msa"].shape, dtype=np.int8) for x in np_chains_list]
np_example["bert_mask"] = block_diag(*msa_masks, pad_value=0)
else:
np_example["cluster_bias_mask"] = np.zeros(np_example["msa"].shape[0])
np_example["cluster_bias_mask"][0] = 1
# Initialize Bert mask with masked out off diagonals.
msa_masks = [np.ones(x["msa"].shape, dtype=np.int8) for x in np_chains_list]
msa_masks_all_seq = [
np.ones(x["msa_all_seq"].shape, dtype=np.int8) for x in np_chains_list
]
msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example["bert_mask"] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0
)
return np_example
def _pad_templates(
chains: Sequence[NumpyDict], max_templates: int
) -> Sequence[NumpyDict]:
"""For each chain pad the number of templates to a fixed size.
Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.
Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode="constant")
return chains
def _merge_features_from_multiple_chains(
chains: Sequence[NumpyDict], pair_msa_sequences: bool
) -> NumpyDict:
"""Merge features from multiple chains.
Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).
Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split("_all_seq")[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or "_all_seq" in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
if feature_name_split == "msa":
merged_example["msa_chains_all_seq"] = np.ones(
merged_example[feature_name].shape[0]
).reshape(-1, 1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name]
)
if feature_name_split == "msa":
msa_chains = []
for i, feat in enumerate(feats):
cur_shape = feat.shape[0]
vals = np.ones(cur_shape) * (i + 2)
msa_chains.append(vals)
merged_example["msa_chains"] = np.concatenate(msa_chains).reshape(
-1, 1
)
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example
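# Illustrative sketch of the two merge modes above: with pair_msa_sequences=False,
# two chains whose "msa" features have shapes (N1, L1) and (N2, L2) are
# block-diagonalised into an (N1 + N2, L1 + L2) array padded with MSA_GAP_IDX, and
# "msa_chains" records the source block of each row (chain i gets value i + 2).
# With pair_msa_sequences=True (and for "*_all_seq" features) the already-paired
# rows are concatenated along the residue dimension into shape (N, L1 + L2), and
# "msa_chains_all_seq" is set to all ones.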
def _merge_homomers_dense_msa(chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]:
"""Merge all identical chains, making the resulting MSA dense.
Args:
chains: An iterable of features for each chain.
Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain["entity_id"][0]
entity_chains[entity_id].append(chain)
grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains
]
return chains
def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES + ("msa_chains",)
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + "_all_seq"]
try:
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
except Exception as ex:
raise Exception(
"concat failed.",
feature_name,
feat_all_seq.shape,
feat.shape,
ex.__class__,
ex,
)
example[feature_name] = merged_feat
example["num_alignments"] = np.array(example["msa"].shape[0], dtype=np.int32)
return example
def merge_chain_features(
np_chains_list: List[NumpyDict], pair_msa_sequences: bool, max_templates: int
) -> NumpyDict:
"""Merges features for multiple chains to single FeatureDict.
Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.
Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will be always block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False
)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences,
)
return np_example
def deduplicate_unpaired_sequences(np_chains: List[NumpyDict]) -> List[NumpyDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""
feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
cache_msa_features = {}
for chain in np_chains:
entity_id = int(chain["entity_id"][0])
if entity_id not in cache_msa_features:
sequence_set = set(s.tobytes() for s in chain["msa_all_seq"])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain["msa"]):
if seq.tobytes() not in sequence_set:
keep_rows.append(row_num)
new_msa_features = {}
for feature_name in feature_names:
if feature_name in msa_features:
if keep_rows:
new_msa_features[feature_name] = chain[feature_name][keep_rows]
else:
new_shape = list(chain[feature_name].shape)
new_shape[0] = 0
new_msa_features[feature_name] = np.zeros(
new_shape, dtype=chain[feature_name].dtype
)
cache_msa_features[entity_id] = new_msa_features
for feature_name in cache_msa_features[entity_id]:
chain[feature_name] = cache_msa_features[entity_id][feature_name]
chain["num_alignments"] = np.array(chain["msa"].shape[0], dtype=np.int32)
return np_chains
from typing import Optional
import torch
import numpy as np
from unifold.data import data_ops
def nonensembled_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
v2_feature = common_cfg.v2_feature
operators = []
if mode_cfg.random_delete_msa:
operators.append(data_ops.random_delete_msa(common_cfg.random_delete_msa))
operators.extend(
[
data_ops.cast_to_64bit_ints,
data_ops.correct_msa_restypes,
data_ops.squeeze_features,
data_ops.randomly_replace_msa_with_unknown(0.0),
data_ops.make_seq_mask,
data_ops.make_msa_mask,
]
)
operators.append(
data_ops.make_hhblits_profile_v2 if v2_feature else data_ops.make_hhblits_profile
)
if common_cfg.use_templates:
operators.extend(
[
data_ops.make_template_mask,
data_ops.make_pseudo_beta("template_"),
]
)
operators.append(
data_ops.crop_templates(
max_templates=mode_cfg.max_templates,
subsample_templates=mode_cfg.subsample_templates,
)
)
if common_cfg.use_template_torsion_angles:
operators.extend(
[
data_ops.atom37_to_torsion_angles("template_"),
]
)
operators.append(data_ops.make_atom14_masks)
operators.append(data_ops.make_target_feat)
return operators
def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
operators = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters
crop_feats = dict(common_cfg.features)
if mode_cfg.fixed_size:
if mode_cfg.crop:
if common_cfg.is_multimer:
crop_fn = data_ops.crop_to_size_multimer(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
ca_ca_threshold=mode_cfg.ca_ca_threshold,
)
else:
crop_fn = data_ops.crop_to_size_single(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
)
operators.append(crop_fn)
operators.append(data_ops.select_feat(crop_feats))
operators.append(
data_ops.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
)
return operators
def ensembled_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that can be ensembled and averaged."""
operators = []
multimer_mode = common_cfg.is_multimer
v2_feature = common_cfg.v2_feature
# Multimer mode does not use block_delete_msa.
if mode_cfg.block_delete_msa and not multimer_mode:
operators.append(data_ops.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg:
operators.append(
data_ops.sample_msa_distillation(mode_cfg.max_distillation_msa_clusters)
)
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
assert common_cfg.resample_msa_in_recycling
gumbel_sample = common_cfg.gumbel_sample
operators.append(
data_ops.sample_msa(
max_msa_clusters,
keep_extra=True,
gumbel_sample=gumbel_sample,
biased_msa_by_chain=mode_cfg.biased_msa_by_chain,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
operators.append(
data_ops.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
gumbel_sample=gumbel_sample,
share_mask=mode_cfg.share_mask,
)
)
if common_cfg.msa_cluster_features:
if v2_feature:
operators.append(data_ops.nearest_neighbor_clusters_v2())
else:
operators.append(data_ops.nearest_neighbor_clusters())
operators.append(data_ops.summarize_clusters)
if v2_feature:
operators.append(data_ops.make_msa_feat_v2)
else:
operators.append(data_ops.make_msa_feat)
# Crop after creating the cluster profiles.
if max_extra_msa:
if v2_feature:
operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
else:
operators.append(data_ops.crop_extra_msa(max_extra_msa))
else:
operators.append(data_ops.delete_extra_msa)
# operators.append(data_operators.select_feat(common_cfg.recycling_features))
return operators
def process_features(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
is_distillation = bool(tensors.get("is_distillation", 0))
multimer_mode = common_cfg.is_multimer
crop_and_fix_size_seed = int(tensors["crop_and_fix_size_seed"])
crop_fn = crop_and_fix_size_fns(
common_cfg,
mode_cfg,
crop_and_fix_size_seed,
)
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_fns(
common_cfg,
mode_cfg,
)
new_d = compose(fns)(d)
if not multimer_mode or is_distillation:
new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
return compose(crop_fn)(new_d)
else: # select after crop for spatial cropping
d = compose(crop_fn)(d)
d = data_ops.select_feat(common_cfg.recycling_features)(d)
return d
nonensembled = nonensembled_fns(common_cfg, mode_cfg)
if mode_cfg.supervised and (not multimer_mode or is_distillation):
nonensembled.extend(label_transform_fn())
tensors = compose(nonensembled)(tensors)
num_recycling = int(tensors["num_recycling_iters"]) + 1
num_ensembles = mode_cfg.num_ensembles
ensemble_tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x),
torch.arange(num_recycling * num_ensembles),
)
tensors = compose(crop_fn)(tensors)
# add a dummy dim to align with recycling features
tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors}
tensors.update(ensemble_tensors)
return tensors
@data_ops.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
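# Illustrative note: data_ops.curry1 lets the call site bind the transform list
# first and apply the feature dict later, as in compose(fns)(d) above; transforms
# run left to right, so compose([f, g])(x) == g(f(x)).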
def pad_then_stack(
values,
):
if len(values[0].shape) >= 1:
size = max(v.shape[0] for v in values)
new_values = []
for v in values:
if v.shape[0] < size:
res = values[0].new_zeros(size, *v.shape[1:])
res[:v.shape[0], ...] = v
else:
res = v
new_values.append(res)
else:
new_values = values
return torch.stack(new_values, dim=0)
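# Illustrative sketch (uses torch as imported above): tensors are zero-padded to
# the largest leading dimension before stacking.
#   pad_then_stack([torch.ones(2, 3), torch.ones(1, 3)]).shape
#   # -> torch.Size([2, 2, 3]); the shorter tensor's missing row is zeros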
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = pad_then_stack(
[dict_i[feat] for dict_i in ensembles]
)
return ensembled_dict
def process_single_label(label: dict, num_ensemble: Optional[int] = None) -> dict:
assert "aatype" in label
assert "all_atom_positions" in label
assert "all_atom_mask" in label
label = compose(label_transform_fn())(label)
if num_ensemble is not None:
label = {
k: torch.stack([v for _ in range(num_ensemble)]) for k, v in label.items()
}
return label
def process_labels(labels_list, num_ensemble: Optional[int] = None):
return [process_single_label(l, num_ensemble) for l in labels_list]
def label_transform_fn():
return [
data_ops.make_atom14_masks,
data_ops.make_atom14_positions,
data_ops.atom37_to_frames,
data_ops.atom37_to_torsion_angles(""),
data_ops.make_pseudo_beta(""),
data_ops.get_backbone_frames,
data_ops.get_chi_angles,
]
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data """
from typing import Iterable, MutableMapping, List
import collections
from unifold.data import residue_constants, msa_pairing
import numpy as np
from .utils import correct_template_restypes
FeatureDict = MutableMapping[str, np.ndarray]
REQUIRED_FEATURES = frozenset(
{
"aatype",
"all_atom_mask",
"all_atom_positions",
"all_chains_entity_ids",
"all_crops_all_chains_mask",
"all_crops_all_chains_positions",
"all_crops_all_chains_residue_ids",
"assembly_num_chains",
"asym_id",
"bert_mask",
"cluster_bias_mask",
"deletion_matrix",
"deletion_mean",
"entity_id",
"entity_mask",
"mem_peak",
"msa",
"msa_mask",
"num_alignments",
"num_templates",
"queue_size",
"residue_index",
"resolution",
"seq_length",
"seq_mask",
"sym_id",
"template_aatype",
"template_all_atom_mask",
"template_all_atom_positions",
# zy added:
"asym_len",
"template_sum_probs",
"num_sym",
"msa_chains",
}
)
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(
np.unique(
np.concatenate(
[
np.unique(chain["entity_id"][chain["entity_id"] > 0])
for chain in chains
]
)
)
)
return num_unique_chains == 1
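# Illustrative sketch: chains sharing a single entity_id form a homomer (or a
# monomer); distinct entity ids mark a heteromer.
#   _is_homomer_or_monomer([{"entity_id": np.array([1, 1])},
#                           {"entity_id": np.array([1, 1])}])  # -> True
#   _is_homomer_or_monomer([{"entity_id": np.array([1, 1])},
#                           {"entity_id": np.array([2, 2])}])  # -> False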
def pair_and_merge(all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict:
"""Runs processing on features to augment, pair and merge.
Args:
all_chain_features: A MutableMapping of dictionaries of features for each chain.
Returns:
A dictionary of features.
"""
process_unmerged_features(all_chain_features)
np_chains_list = all_chain_features
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES,
)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES,
)
np_example = process_final(np_example)
return np_example
def crop_chains(
chains_list: List[FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int,
) -> List[FeatureDict]:
"""Crops the MSAs for a set of chains.
Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.
Returns:
The chains cropped.
"""
# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates,
)
cropped_chains.append(cropped_chain)
return cropped_chains
def _crop_single_chain(
chain: FeatureDict, msa_crop_size: int, pair_msa_sequences: bool, max_templates: int
) -> FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain["num_alignments"]
if pair_msa_sequences:
msa_size_all_seq = chain["num_alignments_all_seq"]
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
# We reduce the number of unpaired sequences by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain["msa_all_seq"][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)
)
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, msa_crop_size_all_seq)
# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_crop_size for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)
include_templates = "template_aatype" in chain and max_templates
if include_templates:
num_templates = chain["template_aatype"].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)
for k in chain:
k_split = k.split("_all_seq")[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if "_all_seq" in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]
chain["num_alignments"] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain["num_templates"] = np.asarray(templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain["num_alignments_all_seq"] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32
)
return chain
def process_final(np_example: FeatureDict) -> FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example
def _make_seq_mask(np_example):
np_example["seq_mask"] = (np_example["entity_id"] > 0).astype(np.float32)
return np_example
def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""
np_example["msa_mask"] = np.ones_like(np_example["msa"], dtype=np.int8)
seq_mask = (np_example["entity_id"] > 0).astype(np.int8)
np_example["msa_mask"] *= seq_mask[None]
return np_example
def _filter_features(np_example: FeatureDict) -> FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}
def process_unmerged_features(all_chain_features: MutableMapping[str, FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features:
# Convert deletion matrices to float.
if "deletion_matrix_int" in chain_features:
chain_features["deletion_matrix"] = np.asarray(
chain_features.pop("deletion_matrix_int"), dtype=np.float32
)
if "deletion_matrix_int_all_seq" in chain_features:
chain_features["deletion_matrix_all_seq"] = np.asarray(
chain_features.pop("deletion_matrix_int_all_seq"), dtype=np.float32
)
chain_features["deletion_mean"] = np.mean(
chain_features["deletion_matrix"], axis=0
)
if "all_atom_positions" not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features["aatype"]
]
chain_features["all_atom_mask"] = all_atom_mask
chain_features["all_atom_positions"] = np.zeros(
list(all_atom_mask.shape) + [3]
)
# Add assembly_num_chains.
chain_features["assembly_num_chains"] = np.asarray(num_chains)
# Add entity_mask.
for chain_features in all_chain_features:
chain_features["entity_mask"] = (chain_features["entity_id"] != 0).astype(
np.int32
)
def empty_template_feats(n_res):
return {
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
"template_all_atom_positions": np.zeros((0, n_res, 37, 3)).astype(np.float32),
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
}
def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
if monomer_features["template_aatype"].shape[0] == 0:
monomer_features.update(
empty_template_feats(monomer_features["aatype"].shape[0])
)
converted = {}
unnecessary_leading_dim_feats = {
"sequence",
"domain_name",
"num_alignments",
"seq_length",
}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == "aatype":
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == "template_aatype":
if feature.shape[0] > 0:
feature = correct_template_restypes(feature)
elif feature_name == "template_all_atom_masks":
feature_name = "template_all_atom_mask"
elif feature_name == "msa":
feature = feature.astype(np.uint8)
if feature_name.endswith("_mask"):
feature = feature.astype(np.float32)
converted[feature_name] = feature
if "deletion_matrix_int" in monomer_features:
monomer_features["deletion_matrix"] = monomer_features.pop(
"deletion_matrix_int"
).astype(np.float32)
converted.pop(
"template_sum_probs"
) # zy: this input has been observed to have an inconsistent shape. TODO: figure out why and fix it.
return converted
def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.
Args:
num: A positive integer.
Returns:
A string that encodes the positive integer using reverse spreadsheet style
naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f"Only positive integers allowed, got {num}.")
num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord("A")))
num = num // 26 - 1
return "".join(output)
def add_assembly_features(
all_chain_features,
):
"""Add features to distinguish between chains.
Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.
Returns:
all_chain_features: A dictionary which maps strings of the form
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two
chains from a homodimer would have keys A_1 and A_2. Two chains from a
heterodimer would have keys A_1 and B_1.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_features in all_chain_features:
assert "sequence" in chain_features
seq = str(chain_features["sequence"])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)
new_all_chain_features = []
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
num_sym = len(group_chain_features) # zy
for sym_id, chain_features in enumerate(group_chain_features, start=1):
seq_length = chain_features["seq_length"]
chain_features["asym_id"] = chain_id * np.ones(seq_length)
chain_features["sym_id"] = sym_id * np.ones(seq_length)
chain_features["entity_id"] = entity_id * np.ones(seq_length)
chain_features["num_sym"] = num_sym * np.ones(seq_length)
chain_id += 1
new_all_chain_features.append(chain_features)
return new_all_chain_features
def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example["msa"].shape[0]
if num_seq < min_num_seq:
for feat in ("msa", "deletion_matrix", "bert_mask", "msa_mask", "msa_chains"):
np_example[feat] = np.pad(
np_example[feat], ((0, min_num_seq - num_seq), (0, 0))
)
np_example["cluster_bias_mask"] = np.pad(
np_example["cluster_bias_mask"], ((0, min_num_seq - num_seq),)
)
return np_example
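# Illustrative sketch: an example whose MSA has fewer than min_num_seq rows is
# zero-padded along the sequence dimension only (assuming the listed MSA features
# and "cluster_bias_mask" are present), e.g.
#   padded = pad_msa(np_example, 512)
#   # padded["msa"].shape[0] == 512 if it was smaller; the residue dim is unchanged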
def post_process(np_example):
np_example = pad_msa(np_example, 512)
no_dim_keys = [
"num_alignments",
"assembly_num_chains",
"num_templates",
"seq_length",
"resolution",
]
for k in no_dim_keys:
if k in np_example:
np_example[k] = np_example[k].reshape(-1)
return np_example
def merge_msas(msa, del_mat, new_msa, new_del_mat):
cur_msa_set = set([tuple(m) for m in msa])
new_rows = []
for i, s in enumerate(new_msa):
if tuple(s) not in cur_msa_set:
new_rows.append(i)
ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0)
ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0)
return ret_msa, ret_del_mat
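# Illustrative sketch: only rows of new_msa that are absent from msa are appended,
# together with the matching deletion-matrix rows.
#   msa = np.array([[0, 1], [2, 3]])
#   new_msa = np.array([[2, 3], [4, 5]])
#   merged, _ = merge_msas(msa, np.zeros_like(msa), new_msa, np.zeros_like(new_msa))
#   # merged -> array([[0, 1], [2, 3], [4, 5]])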
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional
from unifold.data import residue_constants
from Bio.PDB import PDBParser
import numpy as np
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.
# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""
# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, C.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]
# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]
# 0-indexed number corresponding to the chain in the protein that this residue
# belongs to.
chain_index: np.ndarray # [num_res]
# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]
def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
raise ValueError(
f"Cannot build an instance with more than {PDB_MAX_CHAINS} chains "
"because these cannot be written to PDB format."
)
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.
WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.
Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain
is parsed. Otherwise all chains are parsed.
Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure("none", pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f"Only single model PDBs are supported. Found {len(models)} models."
)
model = models[0]
atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []
for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
for res in chain:
if res.id[2] != " ":
raise ValueError(
f"PDB contains an insertion code at chain {chain.id} and residue "
f"index {res.id[1]}. These are not supported."
)
res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num
)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num,))
res_b_factors = np.zeros((residue_constants.atom_type_num,))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)
# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
)
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = "TER"
return (
f"{chain_end:<6}{atom_index:>5} {end_resname:>3} "
f"{chain_name:>1}{residue_index:>4}"
)
def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.
Args:
prot: The protein to convert to PDB.
Returns:
PDB string.
"""
restypes = residue_constants.restypes + ["X"]
res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
atom_types = residue_constants.atom_types
pdb_lines = []
atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors
if np.any(aatype > residue_constants.restype_num):
raise ValueError("Invalid aatypes.")
# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f"The PDB format supports at most {PDB_MAX_CHAINS} chains."
)
chain_ids[i] = PDB_CHAIN_IDS[i]
pdb_lines.append("MODEL 1")
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1],
)
)
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.
res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(
atom_types, atom_positions[i], atom_mask[i], b_factors[i]
):
if mask < 0.5:
continue
record_type = "ATOM"
name = atom_name if len(atom_name) == 4 else f" {atom_name}"
alt_loc = ""
insertion_code = ""
occupancy = 1.00
element = atom_name[0]  # Protein supports only C, N, O, S, so this works.
charge = ""
# PDB is a columnar format, every space matters here!
atom_line = (
f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}"
f"{residue_index[i]:>4}{insertion_code:>1} "
f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
f"{occupancy:>6.2f}{b_factor:>6.2f} "
f"{element:>2}{charge:>2}"
)
pdb_lines.append(atom_line)
atom_index += 1
# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1],
)
)
pdb_lines.append("ENDMDL")
pdb_lines.append("END")
# Pad all lines to 80 characters.
pdb_lines = [line.ljust(80) for line in pdb_lines]
return "\n".join(pdb_lines) + "\n" # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.
`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.
Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.
Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
def from_prediction(
features: FeatureDict, result: ModelOutput, b_factors: Optional[np.ndarray] = None
) -> Protein:
"""Assembles a protein from a prediction.
Args:
features: Dictionary holding model inputs.
result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if "asym_id" in features:
chain_index = features["asym_id"] - 1
else:
chain_index = np.zeros_like((features["aatype"]))
if b_factors is None:
b_factors = np.zeros_like(result["final_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=result["final_atom_positions"],
atom_mask=result["final_atom_mask"],
residue_index=features["residue_index"] + 1,
chain_index=chain_index,
b_factors=b_factors,
)
def from_feature(
features: FeatureDict, b_factors: Optional[np.ndarray] = None
) -> Protein:
"""Assembles a standard pdb from input atom positions & mask.
Args:
features: Dictionary holding model inputs.
b_factors: (Optional) B-factors to use for the protein.
Returns:
A protein instance.
"""
if "asym_id" in features:
chain_index = features["asym_id"] - 1
else:
chain_index = np.zeros_like((features["aatype"]))
if b_factors is None:
b_factors = np.zeros_like(features["all_atom_mask"])
return Protein(
aatype=features["aatype"],
atom_positions=features["all_atom_positions"],
atom_mask=features["all_atom_mask"],
residue_index=features["residue_index"] + 1,
chain_index=chain_index,
b_factors=b_factors,
)
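# Illustrative usage sketch: a features dict holding "aatype", "all_atom_positions",
# "all_atom_mask" and "residue_index" as numpy arrays can be written out as PDB text via
#   pdb_string = to_pdb(from_feature(features))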
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants used in AlphaFold."""
import collections
import functools
import os
from typing import List, Mapping, Tuple
import numpy as np
from unicore.utils import tree_map
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
"ALA": [],
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
"ARG": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "NE"],
["CG", "CD", "NE", "CZ"],
],
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
"CYS": [["N", "CA", "CB", "SG"]],
"GLN": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLU": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "OE1"],
],
"GLY": [],
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"LYS": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "CD"],
["CB", "CG", "CD", "CE"],
["CG", "CD", "CE", "NZ"],
],
"MET": [
["N", "CA", "CB", "CG"],
["CA", "CB", "CG", "SD"],
["CB", "CG", "SD", "CE"],
],
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
"SER": [["N", "CA", "CB", "OG"]],
"THR": [["N", "CA", "CB", "OG1"]],
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
"VAL": [["N", "CA", "CB", "CG1"]],
}
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
chi_angles_mask = [
[0.0, 0.0, 0.0, 0.0], # ALA
[1.0, 1.0, 1.0, 1.0], # ARG
[1.0, 1.0, 0.0, 0.0], # ASN
[1.0, 1.0, 0.0, 0.0], # ASP
[1.0, 0.0, 0.0, 0.0], # CYS
[1.0, 1.0, 1.0, 0.0], # GLN
[1.0, 1.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[1.0, 1.0, 0.0, 0.0], # HIS
[1.0, 1.0, 0.0, 0.0], # ILE
[1.0, 1.0, 0.0, 0.0], # LEU
[1.0, 1.0, 1.0, 1.0], # LYS
[1.0, 1.0, 1.0, 0.0], # MET
[1.0, 1.0, 0.0, 0.0], # PHE
[1.0, 1.0, 0.0, 0.0], # PRO
[1.0, 0.0, 0.0, 0.0], # SER
[1.0, 0.0, 0.0, 0.0], # THR
[1.0, 1.0, 0.0, 0.0], # TRP
[1.0, 1.0, 0.0, 0.0], # TYR
[1.0, 0.0, 0.0, 0.0], # VAL
]
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
chi_pi_periodic = [
[0.0, 0.0, 0.0, 0.0], # ALA
[0.0, 0.0, 0.0, 0.0], # ARG
[0.0, 0.0, 0.0, 0.0], # ASN
[0.0, 1.0, 0.0, 0.0], # ASP
[0.0, 0.0, 0.0, 0.0], # CYS
[0.0, 0.0, 0.0, 0.0], # GLN
[0.0, 0.0, 1.0, 0.0], # GLU
[0.0, 0.0, 0.0, 0.0], # GLY
[0.0, 0.0, 0.0, 0.0], # HIS
[0.0, 0.0, 0.0, 0.0], # ILE
[0.0, 0.0, 0.0, 0.0], # LEU
[0.0, 0.0, 0.0, 0.0], # LYS
[0.0, 0.0, 0.0, 0.0], # MET
[0.0, 1.0, 0.0, 0.0], # PHE
[0.0, 0.0, 0.0, 0.0], # PRO
[0.0, 0.0, 0.0, 0.0], # SER
[0.0, 0.0, 0.0, 0.0], # THR
[0.0, 0.0, 0.0, 0.0], # TRP
[0.0, 1.0, 0.0, 0.0], # TYR
[0.0, 0.0, 0.0, 0.0], # VAL
[0.0, 0.0, 0.0, 0.0], # UNK
]
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
"ALA": [
["N", 0, (-0.525, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.529, -0.774, -1.205)],
["O", 3, (0.627, 1.062, 0.000)],
],
"ARG": [
["N", 0, (-0.524, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.524, -0.778, -1.209)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.616, 1.390, -0.000)],
["CD", 5, (0.564, 1.414, 0.000)],
["NE", 6, (0.539, 1.357, -0.000)],
["NH1", 7, (0.206, 2.301, 0.000)],
["NH2", 7, (2.078, 0.978, -0.000)],
["CZ", 7, (0.758, 1.093, -0.000)],
],
"ASN": [
["N", 0, (-0.536, 1.357, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.531, -0.787, -1.200)],
["O", 3, (0.625, 1.062, 0.000)],
["CG", 4, (0.584, 1.399, 0.000)],
["ND2", 5, (0.593, -1.188, 0.001)],
["OD1", 5, (0.633, 1.059, 0.000)],
],
"ASP": [
["N", 0, (-0.525, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, 0.000, -0.000)],
["CB", 0, (-0.526, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.593, 1.398, -0.000)],
["OD1", 5, (0.610, 1.091, 0.000)],
["OD2", 5, (0.592, -1.101, -0.003)],
],
"CYS": [
["N", 0, (-0.522, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, 0.000)],
["CB", 0, (-0.519, -0.773, -1.212)],
["O", 3, (0.625, 1.062, -0.000)],
["SG", 4, (0.728, 1.653, 0.000)],
],
"GLN": [
["N", 0, (-0.526, 1.361, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.525, -0.779, -1.207)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.615, 1.393, 0.000)],
["CD", 5, (0.587, 1.399, -0.000)],
["NE2", 6, (0.593, -1.189, -0.001)],
["OE1", 6, (0.634, 1.060, 0.000)],
],
"GLU": [
["N", 0, (-0.528, 1.361, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, -0.000, -0.000)],
["CB", 0, (-0.526, -0.781, -1.207)],
["O", 3, (0.626, 1.062, 0.000)],
["CG", 4, (0.615, 1.392, 0.000)],
["CD", 5, (0.600, 1.397, 0.000)],
["OE1", 6, (0.607, 1.095, -0.000)],
["OE2", 6, (0.589, -1.104, -0.001)],
],
"GLY": [
["N", 0, (-0.572, 1.337, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.517, -0.000, -0.000)],
["O", 3, (0.626, 1.062, -0.000)],
],
"HIS": [
["N", 0, (-0.527, 1.360, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.525, -0.778, -1.208)],
["O", 3, (0.625, 1.063, 0.000)],
["CG", 4, (0.600, 1.370, -0.000)],
["CD2", 5, (0.889, -1.021, 0.003)],
["ND1", 5, (0.744, 1.160, -0.000)],
["CE1", 5, (2.030, 0.851, 0.002)],
["NE2", 5, (2.145, -0.466, 0.004)],
],
"ILE": [
["N", 0, (-0.493, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.536, -0.793, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.534, 1.437, -0.000)],
["CG2", 4, (0.540, -0.785, -1.199)],
["CD1", 5, (0.619, 1.391, 0.000)],
],
"LEU": [
["N", 0, (-0.520, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.522, -0.773, -1.214)],
["O", 3, (0.625, 1.063, -0.000)],
["CG", 4, (0.678, 1.371, 0.000)],
["CD1", 5, (0.530, 1.430, -0.000)],
["CD2", 5, (0.535, -0.774, 1.200)],
],
"LYS": [
["N", 0, (-0.526, 1.362, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, 0.000)],
["CB", 0, (-0.524, -0.778, -1.208)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.619, 1.390, 0.000)],
["CD", 5, (0.559, 1.417, 0.000)],
["CE", 6, (0.560, 1.416, 0.000)],
["NZ", 7, (0.554, 1.387, 0.000)],
],
"MET": [
["N", 0, (-0.521, 1.364, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, 0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.210)],
["O", 3, (0.625, 1.062, -0.000)],
["CG", 4, (0.613, 1.391, -0.000)],
["SD", 5, (0.703, 1.695, 0.000)],
["CE", 6, (0.320, 1.786, -0.000)],
],
"PHE": [
["N", 0, (-0.518, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, 0.000, -0.000)],
["CB", 0, (-0.525, -0.776, -1.212)],
["O", 3, (0.626, 1.062, -0.000)],
["CG", 4, (0.607, 1.377, 0.000)],
["CD1", 5, (0.709, 1.195, -0.000)],
["CD2", 5, (0.706, -1.196, 0.000)],
["CE1", 5, (2.102, 1.198, -0.000)],
["CE2", 5, (2.098, -1.201, -0.000)],
["CZ", 5, (2.794, -0.003, -0.001)],
],
"PRO": [
["N", 0, (-0.566, 1.351, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, 0.000)],
["CB", 0, (-0.546, -0.611, -1.293)],
["O", 3, (0.621, 1.066, 0.000)],
["CG", 4, (0.382, 1.445, 0.0)],
# ['CD', 5, (0.427, 1.440, 0.0)],
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
],
"SER": [
["N", 0, (-0.529, 1.360, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, -0.000)],
["CB", 0, (-0.518, -0.777, -1.211)],
["O", 3, (0.626, 1.062, -0.000)],
["OG", 4, (0.503, 1.325, 0.000)],
],
"THR": [
["N", 0, (-0.517, 1.364, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.526, 0.000, -0.000)],
["CB", 0, (-0.516, -0.793, -1.215)],
["O", 3, (0.626, 1.062, 0.000)],
["CG2", 4, (0.550, -0.718, -1.228)],
["OG1", 4, (0.472, 1.353, 0.000)],
],
"TRP": [
["N", 0, (-0.521, 1.363, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.525, -0.000, 0.000)],
["CB", 0, (-0.523, -0.776, -1.212)],
["O", 3, (0.627, 1.062, 0.000)],
["CG", 4, (0.609, 1.370, -0.000)],
["CD1", 5, (0.824, 1.091, 0.000)],
["CD2", 5, (0.854, -1.148, -0.005)],
["CE2", 5, (2.186, -0.678, -0.007)],
["CE3", 5, (0.622, -2.530, -0.007)],
["NE1", 5, (2.140, 0.690, -0.004)],
["CH2", 5, (3.028, -2.890, -0.013)],
["CZ2", 5, (3.283, -1.543, -0.011)],
["CZ3", 5, (1.715, -3.389, -0.011)],
],
"TYR": [
["N", 0, (-0.522, 1.362, 0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.524, -0.000, -0.000)],
["CB", 0, (-0.522, -0.776, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG", 4, (0.607, 1.382, -0.000)],
["CD1", 5, (0.716, 1.195, -0.000)],
["CD2", 5, (0.713, -1.194, -0.001)],
["CE1", 5, (2.107, 1.200, -0.002)],
["CE2", 5, (2.104, -1.201, -0.003)],
["OH", 5, (4.168, -0.002, -0.005)],
["CZ", 5, (2.791, -0.001, -0.003)],
],
"VAL": [
["N", 0, (-0.494, 1.373, -0.000)],
["CA", 0, (0.000, 0.000, 0.000)],
["C", 0, (1.527, -0.000, -0.000)],
["CB", 0, (-0.533, -0.795, -1.213)],
["O", 3, (0.627, 1.062, -0.000)],
["CG1", 4, (0.540, 1.429, -0.000)],
["CG2", 4, (0.533, -0.776, 1.203)],
],
}
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
"ALA": ["C", "CA", "CB", "N", "O"],
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
"GLY": ["C", "CA", "N", "O"],
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
"SER": ["C", "CA", "CB", "N", "O", "OG"],
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
"TRP": [
"C",
"CA",
"CB",
"CG",
"CD1",
"CD2",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
"N",
"NE1",
"O",
],
"TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
residue_atom_renaming_swaps = {
"ASP": {"OD1": "OD2"},
"GLU": {"OE1": "OE2"},
"PHE": {"CD1": "CD2", "CE1": "CE2"},
"TYR": {"CD1": "CD2", "CE1": "CE2"},
}
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
"C": 1.7,
"N": 1.55,
"O": 1.52,
"S": 1.8,
}
Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
BondAngle = collections.namedtuple(
"BondAngle", ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"]
)
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
Mapping[str, List[Bond]], Mapping[str, List[Bond]], Mapping[str, List[BondAngle]]
]:
"""Load stereo_chemical_props.txt into a nice structure.
Load literature values for bond lengths and bond angles and translate
bond angles into the length of the opposite edge of the triangle
("residue_virtual_bonds").
Returns:
residue_bonds: Dict that maps resname -> list of Bond tuples.
residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
"""
stereo_chemical_props_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "stereo_chemical_props.txt"
)
with open(stereo_chemical_props_path, "rt") as f:
stereo_chemical_props = f.read()
lines_iter = iter(stereo_chemical_props.splitlines())
# Load bond lengths.
residue_bonds = {}
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, length, stddev = line.split()
atom1, atom2 = bond.split("-")
if resname not in residue_bonds:
residue_bonds[resname] = []
residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
residue_bonds["UNK"] = []
# Load bond angles.
residue_bond_angles = {}
next(lines_iter) # Skip empty line.
next(lines_iter) # Skip header line.
for line in lines_iter:
if line.strip() == "-":
break
bond, resname, angle_degree, stddev_degree = line.split()
atom1, atom2, atom3 = bond.split("-")
if resname not in residue_bond_angles:
residue_bond_angles[resname] = []
residue_bond_angles[resname].append(
BondAngle(
atom1,
atom2,
atom3,
float(angle_degree) / 180.0 * np.pi,
float(stddev_degree) / 180.0 * np.pi,
)
)
residue_bond_angles["UNK"] = []
def make_bond_key(atom1_name, atom2_name):
"""Unique key to lookup bonds."""
return "-".join(sorted([atom1_name, atom2_name]))
# Translate bond angles into distances ("virtual bonds").
residue_virtual_bonds = {}
for resname, bond_angles in residue_bond_angles.items():
# Create a fast lookup dict for bond lengths.
bond_cache = {}
for b in residue_bonds[resname]:
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
residue_virtual_bonds[resname] = []
for ba in bond_angles:
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
# Compute distance between atom1 and atom3 using the law of cosines
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
gamma = ba.angle_rad
length = np.sqrt(
bond1.length**2
+ bond2.length**2
- 2 * bond1.length * bond2.length * np.cos(gamma)
)
# Propagation of uncertainty assuming uncorrelated errors.
dl_outer = 0.5 / length
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
stddev = np.sqrt(
(dl_dgamma * ba.stddev) ** 2
+ (dl_db1 * bond1.stddev) ** 2
+ (dl_db2 * bond2.stddev) ** 2
)
residue_virtual_bonds[resname].append(
Bond(ba.atom1_name, ba.atom3name, length, stddev)
)
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
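# A minimal worked example (not part of the original module) of the virtual-bond
# computation above, using the ALA literature values from stereo_chemical_props.txt:
# N-CA = 1.459 A, CA-C = 1.525 A, N-CA-C = 111.0 deg. The law of cosines gives
#   |N...C| = sqrt(1.459**2 + 1.525**2 - 2 * 1.459 * 1.525 * cos(111.0 deg)) ~= 2.46 A,
# which becomes the Bond("N", "C", ...) entry stored in residue_virtual_bonds["ALA"].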
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]
# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
"N",
"CA",
"C",
"CB",
"O",
"CG",
"CG1",
"CG2",
"OG",
"OG1",
"SG",
"CD",
"CD1",
"CD2",
"ND1",
"ND2",
"OD1",
"OD2",
"SD",
"CE",
"CE1",
"CE2",
"CE3",
"NE",
"NE1",
"NE2",
"OE1",
"OE2",
"CH2",
"NH1",
"NH2",
"OH",
"CZ",
"CZ2",
"CZ3",
"NZ",
"OXT",
]
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types) # := 37.
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
"ARG": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD",
"NE",
"CZ",
"NH1",
"NH2",
"",
"",
"",
],
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
"HIS": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"ND1",
"CD2",
"CE1",
"NE2",
"",
"",
"",
"",
],
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
"PHE": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"",
"",
"",
],
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
"TRP": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"NE1",
"CE2",
"CE3",
"CZ2",
"CZ3",
"CH2",
],
"TYR": [
"N",
"CA",
"C",
"O",
"CB",
"CG",
"CD1",
"CD2",
"CE1",
"CE2",
"CZ",
"OH",
"",
"",
],
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
"UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
"A",
"R",
"N",
"D",
"C",
"Q",
"E",
"G",
"H",
"I",
"L",
"K",
"M",
"F",
"P",
"S",
"T",
"W",
"Y",
"V",
]
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes) # := 20.
unk_restype_index = restype_num # Catch-all index for unknown restypes.
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
def sequence_to_onehot(
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
"""Maps the given sequence into a one-hot encoded matrix.
Args:
sequence: An amino acid sequence.
mapping: A dictionary mapping amino acids to integers.
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
amino acid 'X', an error will be thrown. If False, any amino acid not in
the mapping will throw an error.
Returns:
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
the sequence.
Raises:
ValueError: If the mapping doesn't contain values from 0 to
num_unique_aas - 1 without any gaps.
"""
num_entries = max(mapping.values()) + 1
if sorted(set(mapping.values())) != list(range(num_entries)):
raise ValueError(
"The mapping must have values from 0 to num_unique_aas-1 "
"without any gaps. Got: %s" % sorted(mapping.values())
)
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
for aa_index, aa_type in enumerate(sequence):
if map_unknown_to_x:
if aa_type.isalpha() and aa_type.isupper():
aa_id = mapping.get(aa_type, mapping["X"])
else:
raise ValueError(f"Invalid character in the sequence: {aa_type}")
else:
aa_id = mapping[aa_type]
one_hot_arr[aa_index, aa_id] = 1
return one_hot_arr
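# Illustrative usage (a sketch, not part of the original module): with the
# restype_order_with_x mapping defined above, a 3-residue sequence maps to a
# (3, 21) one-hot matrix, and unknown letters fall back to 'X' when requested.
#
#   onehot = sequence_to_onehot("ACB", restype_order_with_x, map_unknown_to_x=True)
#   assert onehot.shape == (3, 21) and onehot[2, restype_order_with_x["X"]] == 1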
restype_1to3 = {
"A": "ALA",
"R": "ARG",
"N": "ASN",
"D": "ASP",
"C": "CYS",
"Q": "GLN",
"E": "GLU",
"G": "GLY",
"H": "HIS",
"I": "ILE",
"L": "LEU",
"K": "LYS",
"M": "MET",
"F": "PHE",
"P": "PRO",
"S": "SER",
"T": "THR",
"W": "TRP",
"Y": "TYR",
"V": "VAL",
}
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}
# Define a restype name for all unknown residues.
unk_restype = "UNK"
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
"A": 0,
"B": 2,
"C": 1,
"D": 2,
"E": 3,
"F": 4,
"G": 5,
"H": 6,
"I": 7,
"J": 20,
"K": 8,
"L": 9,
"M": 10,
"N": 11,
"O": 20,
"P": 12,
"Q": 13,
"R": 14,
"S": 15,
"T": 16,
"U": 1,
"V": 17,
"W": 18,
"X": 20,
"Y": 19,
"Z": 3,
"-": 21,
}
# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
0: "A",
1: "C", # Also U.
2: "D", # Also B.
3: "E", # Also Z.
4: "F",
5: "G",
6: "H",
7: "I",
8: "K",
9: "L",
10: "M",
11: "N",
12: "P",
13: "Q",
14: "R",
15: "S",
16: "T",
17: "V",
18: "W",
19: "Y",
20: "X", # Includes J and O.
21: "-",
}
restypes_with_x_and_gap = restypes + ["X", "-"]
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
for i in range(len(restypes_with_x_and_gap))
)
def _make_standard_atom_mask() -> np.ndarray:
"""Returns [num_res_types, num_atom_types] mask array."""
# +1 to account for unknown (all 0s).
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
for restype, restype_letter in enumerate(restypes):
restype_name = restype_1to3[restype_letter]
atom_names = residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = atom_order[atom_name]
mask[restype, atom_type] = 1
return mask
STANDARD_ATOM_MASK = _make_standard_atom_mask()
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
"""Define chi-angle rigid groups via one-hot representations."""
chi_angles_index = {}
one_hots = []
for k, v in chi_angles_atoms.items():
indices = [atom_types.index(s[atom_index]) for s in v]
indices.extend([-1] * (4 - len(indices)))
chi_angles_index[k] = indices
for r in restypes:
res3 = restype_1to3[r]
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
one_hots.append(one_hot)
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
one_hot = np.stack(one_hots, axis=0)
one_hot = np.transpose(one_hot, [0, 2, 1])
return one_hot
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree_map(
lambda n: atom_order[n], chi_angles_atom_indices, leaf_type=str
)
chi_angles_atom_indices = np.array(
[
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
for chi_atoms in chi_angles_atom_indices
]
)
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
for atom_i, atom in enumerate(chi_group):
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
chi_groups_for_atom = dict(chi_groups_for_atom)
def _make_rigid_transformation_4x4(ex, ey, translation):
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
# Normalize ex.
ex_normalized = ex / np.linalg.norm(ex)
# make ey perpendicular to ex
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
ey_normalized /= np.linalg.norm(ey_normalized)
# compute ez as cross product
eznorm = np.cross(ex_normalized, ey_normalized)
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
return m
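# Illustrative sanity check (a sketch, not part of the original module): for
# already orthonormal inputs the rotation part is the identity and the
# translation ends up in the last column.
#
#   m = _make_rigid_transformation_4x4(
#       ex=np.array([1.0, 0.0, 0.0]),
#       ey=np.array([0.0, 1.0, 0.0]),
#       translation=np.array([1.0, 2.0, 3.0]),
#   )
#   # m == [[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]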
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int_)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int_)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
def _make_rigid_group_constants():
"""Fill the arrays above."""
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
atomtype = atom_order[atomname]
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
restype_atom37_mask[restype, atomtype] = 1
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
restype_atom14_mask[restype, atom14idx] = 1
restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_positions = {
name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
}
# backbone to backbone is the identity transform
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
# pre-omega-frame to backbone (currently dummy identity matrix)
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
# phi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["N"] - atom_positions["CA"],
ey=np.array([1.0, 0.0, 0.0]),
translation=atom_positions["N"],
)
restype_rigid_group_default_frame[restype, 2, :, :] = mat
# psi-frame to backbone
mat = _make_rigid_transformation_4x4(
ex=atom_positions["C"] - atom_positions["CA"],
ey=atom_positions["CA"] - atom_positions["N"],
translation=atom_positions["C"],
)
restype_rigid_group_default_frame[restype, 3, :, :] = mat
# chi1-frame to backbone
if chi_angles_mask[restype][0]:
base_atom_names = chi_angles_atoms[resname][0]
base_atom_positions = [atom_positions[name] for name in base_atom_names]
mat = _make_rigid_transformation_4x4(
ex=base_atom_positions[2] - base_atom_positions[1],
ey=base_atom_positions[0] - base_atom_positions[1],
translation=base_atom_positions[2],
)
restype_rigid_group_default_frame[restype, 4, :, :] = mat
# chi2-frame to chi1-frame
# chi3-frame to chi2-frame
# chi4-frame to chi3-frame
# luckily all rotation axes for the next frame start at (0,0,0) of the
# previous frame
for chi_idx in range(1, 4):
if chi_angles_mask[restype][chi_idx]:
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
axis_end_atom_position = atom_positions[axis_end_atom_name]
mat = _make_rigid_transformation_4x4(
ex=axis_end_atom_position,
ey=np.array([-1.0, 0.0, 0.0]),
translation=axis_end_atom_position,
)
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
_make_rigid_group_constants()
def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15):
"""compute upper and lower bounds for bonds to assess violations."""
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
for restype, restype_letter in enumerate(restypes):
resname = restype_1to3[restype_letter]
atom_list = restype_name_to_atom14_names[resname]
# create lower and upper bounds for clashes
for atom1_idx, atom1_name in enumerate(atom_list):
if not atom1_name:
continue
atom1_radius = van_der_waals_radius[atom1_name[0]]
for atom2_idx, atom2_name in enumerate(atom_list):
if (not atom2_name) or atom1_idx == atom2_idx:
continue
atom2_radius = van_der_waals_radius[atom2_name[0]]
lower = atom1_radius + atom2_radius - overlap_tolerance
upper = 1e10
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
# overwrite lower and upper bounds for bonds and angles
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
atom1_idx = atom_list.index(b.atom1_name)
atom2_idx = atom_list.index(b.atom2_name)
lower = b.length - bond_length_tolerance_factor * b.stddev
upper = b.length + bond_length_tolerance_factor * b.stddev
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
return {
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
}
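# Illustrative numbers (a sketch, not part of the original module), using the
# defaults above: for the bonded ALA N-CA pair (length 1.459 A, stddev 0.020 A)
# the bounds become 1.459 +/- 15 * 0.020, i.e. roughly [1.159, 1.759] A, while an
# unbonded carbon-carbon pair gets a clash lower bound of 1.7 + 1.7 - 1.5 = 1.9 A
# and an effectively unbounded upper limit (1e10).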
def _make_atom14_and_atom37_constants():
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in restypes:
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
restype_atom14_to_atom37.append(
[(atom_order[name] if name else 0) for name in atom_names]
)
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append(
[
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in atom_types
]
)
restype_atom14_mask.append([(1.0 if name else 0.0) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)
return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask
(
restype_atom14_to_atom37,
restype_atom37_to_atom14,
restype_atom14_mask,
) = _make_atom14_and_atom37_constants()
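# Illustrative mapping (a sketch, not part of the original module): for ALA the
# atom14 slots are ["N", "CA", "C", "O", "CB", "", ...], so
# restype_atom14_to_atom37[restype_order["A"]] starts with [0, 1, 2, 4, 3, 0, ...]
# (indices into atom_types), and restype_atom14_mask marks the five real atoms.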
def _make_renaming_matrices():
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground-truth coordinates where the naming is swapped.
restype_3 = [restype_1to3[res] for res in restypes]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
all_matrices = {res: np.eye(14) for res in restype_3}
for resname, swap in residue_atom_renaming_swaps.items():
correspondences = np.arange(14)
for source_atom_swap, target_atom_swap in swap.items():
source_index = restype_name_to_atom14_names[resname].index(source_atom_swap)
target_index = restype_name_to_atom14_names[resname].index(target_atom_swap)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = np.zeros((14, 14))
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3])
return renaming_matrices
renaming_matrices = _make_renaming_matrices()
def _make_atom14_is_ambiguous():
# Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = np.zeros((21, 14))
for resname, swap in residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = restype_order[restype_3to1[resname]]
atom_idx1 = restype_name_to_atom14_names[resname].index(atom_name1)
atom_idx2 = restype_name_to_atom14_names[resname].index(atom_name2)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
return restype_atom14_is_ambiguous
restype_atom14_is_ambiguous = _make_atom14_is_ambiguous()
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
atom indices default to 0.
"""
chi_atom_indices = []
for residue_name in restypes:
residue_name = restype_1to3[residue_name]
residue_chi_angles = chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append([atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return chi_atom_indices
chi_atom_indices = get_chi_atom_indices()
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-
Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-
Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-
import copy as copy_lib
import functools
import gzip
import json
import numpy as np
import pickle
from scipy import sparse as sp
from typing import *
from . import residue_constants as rc
from .data_ops import NumpyDict
def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False):
if deepcopy:
def decorator(f):
cached_func = functools.lru_cache(maxsize, typed)(f)
@functools.wraps(f)
def wrapper(*args, **kwargs):
return copy_lib.deepcopy(cached_func(*args, **kwargs))
return wrapper
elif copy:
def decorator(f):
cached_func = functools.lru_cache(maxsize, typed)(f)
@functools.wraps(f)
def wrapper(*args, **kwargs):
return copy_lib.copy(cached_func(*args, **kwargs))
return wrapper
else:
decorator = functools.lru_cache(maxsize, typed)
return decorator
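# Illustrative usage (a sketch, not part of the original module): the copy /
# deepcopy flags ensure that callers mutating a returned dict do not corrupt
# the cached value shared by later calls. The path below is hypothetical.
#
#   @lru_cache(maxsize=2, deepcopy=True)
#   def _load_config(path: str) -> dict:
#       with open(path) as f:
#           return json.load(f)
#
#   cfg = _load_config("config.json")
#   cfg["mutated"] = True  # does not affect the cached copy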
@lru_cache(maxsize=8, deepcopy=True)
def load_pickle_safe(path: str) -> Dict[str, Any]:
def load(path):
assert path.endswith(".pkl") or path.endswith(
".pkl.gz"
), f"bad suffix in {path} as pickle file."
open_fn = gzip.open if path.endswith(".gz") else open
with open_fn(path, "rb") as f:
return pickle.load(f)
ret = load(path)
ret = uncompress_features(ret)
return ret
@lru_cache(maxsize=8, copy=True)
def load_pickle(path: str) -> Dict[str, Any]:
def load(path):
assert path.endswith(".pkl") or path.endswith(
".pkl.gz"
), f"bad suffix in {path} as pickle file."
open_fn = gzip.open if path.endswith(".gz") else open
with open_fn(path, "rb") as f:
return pickle.load(f)
ret = load(path)
ret = uncompress_features(ret)
return ret
def correct_template_restypes(feature):
"""Correct template restype to have the same order as residue_constants."""
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
return feature
def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict:
feature["msa"] = feature["msa"].astype(np.uint8)
if "num_alignments" in feature:
feature.pop("num_alignments")
make_all_seq_key = lambda k: f"{k}_all_seq" if not k.endswith("_all_seq") else k
return {make_all_seq_key(k): v for k, v in feature.items()}
def to_dense_matrix(spmat_dict: NumpyDict):
spmat = sp.coo_matrix(
(spmat_dict["data"], (spmat_dict["row"], spmat_dict["col"])),
shape=spmat_dict["shape"],
dtype=np.float32,
)
return spmat.toarray()
FEATS_DTYPE = {"msa": np.int32}
def uncompress_features(feats: NumpyDict) -> NumpyDict:
if "sparse_deletion_matrix_int" in feats:
v = feats.pop("sparse_deletion_matrix_int")
v = to_dense_matrix(v)
feats["deletion_matrix"] = v
return feats
def filter(feature: NumpyDict, **kwargs) -> NumpyDict:
assert len(kwargs) == 1, f"wrong usage of filter with kwargs: {kwargs}"
if "desired_keys" in kwargs:
feature = {k: v for k, v in feature.items() if k in kwargs["desired_keys"]}
elif "required_keys" in kwargs:
for k in kwargs["required_keys"]:
assert k in feature, f"cannot find required key {k}."
elif "ignored_keys" in kwargs:
feature = {k: v for k, v in feature.items() if k not in kwargs["ignored_keys"]}
else:
raise AssertionError(f"wrong usage of filter with kwargs: {kwargs}")
return feature
def compress_features(features: NumpyDict):
change_dtype = {
"msa": np.uint8,
}
sparse_keys = ["deletion_matrix_int"]
compressed_features = {}
for k, v in features.items():
if k in change_dtype:
v = v.astype(change_dtype[k])
if k in sparse_keys:
v = sp.coo_matrix(v, dtype=v.dtype)
sp_v = {"shape": v.shape, "row": v.row, "col": v.col, "data": v.data}
k = f"sparse_{k}"
v = sp_v
compressed_features[k] = v
return compressed_features
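# Note (a sketch of intended usage, not part of the original module):
# compress_features narrows "msa" to uint8 and stores "deletion_matrix_int" as a
# COO sparse dict under the key "sparse_deletion_matrix_int"; uncompress_features
# (above) densifies that entry back under the key "deletion_matrix". Example:
#
#   feats = {"msa": np.zeros((4, 8), dtype=np.int32),
#            "deletion_matrix_int": np.eye(4, 8, dtype=np.int32)}
#   restored = uncompress_features(compress_features(feats))
#   assert restored["deletion_matrix"].shape == (4, 8)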
import os
import json
import ml_collections as mlc
import numpy as np
import copy
import torch
from typing import *
from unifold.data import utils
from unifold.data.data_ops import NumpyDict, TorchDict
from unifold.data.process import process_features, process_labels
from unifold.data.process_multimer import (
pair_and_merge,
add_assembly_features,
convert_monomer_features,
post_process,
merge_msas,
)
from unicore.data import UnicoreDataset, data_utils
from unicore.distributed import utils as distributed_utils
Rotation = Iterable[Iterable]
Translation = Iterable
Operation = Union[str, Tuple[Rotation, Translation]]
NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]]
TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]]
import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def make_data_config(
config: mlc.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[mlc.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if cfg.common.is_multimer:
feature_names += cfg.common.multimer_features
if cfg[mode].supervised:
feature_names += cfg.supervised.supervised_features
return cfg, feature_names
def process_label(all_atom_positions: np.ndarray, operation: Operation) -> np.ndarray:
if operation == "I":
return all_atom_positions
rot, trans = operation
rot = np.array(rot).reshape(3, 3)
trans = np.array(trans).reshape(3)
return all_atom_positions @ rot.T + trans
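# Illustrative behaviour (a sketch, not part of the original module): the string
# "I" is the identity operation, while a (rotation, translation) pair is applied
# as x @ R^T + t to every atom position.
#
#   pos = np.ones((5, 37, 3))
#   same = process_label(pos, "I")                              # unchanged
#   moved = process_label(pos, (np.eye(3), [1.0, 0.0, 0.0]))    # shifted by +1 in x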
@utils.lru_cache(maxsize=8, copy=True)
def load_single_feature(
sequence_id: str,
monomer_feature_dir: str,
uniprot_msa_dir: Optional[str] = None,
is_monomer: bool = False,
) -> NumpyDict:
monomer_feature = utils.load_pickle(
os.path.join(monomer_feature_dir, f"{sequence_id}.feature.pkl.gz")
)
monomer_feature = convert_monomer_features(monomer_feature)
chain_feature = {**monomer_feature}
if uniprot_msa_dir is not None:
all_seq_feature = utils.load_pickle(
os.path.join(uniprot_msa_dir, f"{sequence_id}.uniprot.pkl.gz")
)
if is_monomer:
chain_feature["msa"], chain_feature["deletion_matrix"] = merge_msas(
chain_feature["msa"],
chain_feature["deletion_matrix"],
all_seq_feature["msa"],
all_seq_feature["deletion_matrix"],
)
else:
all_seq_feature = utils.convert_all_seq_feature(all_seq_feature)
for key in [
"msa_all_seq",
"msa_species_identifiers_all_seq",
"deletion_matrix_all_seq",
]:
chain_feature[key] = all_seq_feature[key]
return chain_feature
def load_single_label(
label_id: str,
label_dir: str,
symmetry_operation: Optional[Operation] = None,
) -> NumpyDict:
label = utils.load_pickle(os.path.join(label_dir, f"{label_id}.label.pkl.gz"))
if symmetry_operation is not None:
label["all_atom_positions"] = process_label(
label["all_atom_positions"], symmetry_operation
)
label = {
k: v
for k, v in label.items()
if k in ["aatype", "all_atom_positions", "all_atom_mask", "resolution"]
}
return label
def load(
sequence_ids: List[str],
monomer_feature_dir: str,
uniprot_msa_dir: Optional[str] = None,
label_ids: Optional[List[str]] = None,
label_dir: Optional[str] = None,
symmetry_operations: Optional[List[Operation]] = None,
is_monomer: bool = False,
) -> NumpyExample:
all_chain_features = [
load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, is_monomer)
for s in sequence_ids
]
if label_ids is not None:
# load labels
assert len(label_ids) == len(sequence_ids)
assert label_dir is not None
if symmetry_operations is None:
symmetry_operations = ["I" for _ in label_ids]
all_chain_labels = [
load_single_label(l, label_dir, o)
for l, o in zip(label_ids, symmetry_operations)
]
# update labels into features to calculate spatial cropping etc.
[f.update(l) for f, l in zip(all_chain_features, all_chain_labels)]
all_chain_features = add_assembly_features(all_chain_features)
# get labels back from features, as add_assembly_features may alter the order of inputs.
if label_ids is not None:
all_chain_labels = [
{
k: f[k]
for k in ["aatype", "all_atom_positions", "all_atom_mask", "resolution"]
}
for f in all_chain_features
]
else:
all_chain_labels = None
asym_len = np.array([c["seq_length"] for c in all_chain_features], dtype=np.int64)
if is_monomer:
all_chain_features = all_chain_features[0]
else:
all_chain_features = pair_and_merge(all_chain_features)
all_chain_features = post_process(all_chain_features)
all_chain_features["asym_len"] = asym_len
return all_chain_features, all_chain_labels
def process(
config: mlc.ConfigDict,
mode: str,
features: NumpyDict,
labels: Optional[List[NumpyDict]] = None,
seed: int = 0,
batch_idx: Optional[int] = None,
data_idx: Optional[int] = None,
is_distillation: bool = False,
) -> TorchExample:
if mode == "train":
assert batch_idx is not None
with data_utils.numpy_seed(seed, batch_idx, key="recycling"):
num_iters = np.random.randint(0, config.common.max_recycling_iters + 1)
use_clamped_fape = np.random.rand() < config[mode].use_clamped_fape_prob
else:
num_iters = config.common.max_recycling_iters
use_clamped_fape = 1
features["num_recycling_iters"] = int(num_iters)
features["use_clamped_fape"] = int(use_clamped_fape)
features["is_distillation"] = int(is_distillation)
if is_distillation and "msa_chains" in features:
features.pop("msa_chains")
num_res = int(features["seq_length"])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if labels is not None:
features["resolution"] = labels[0]["resolution"].reshape(-1)
with data_utils.numpy_seed(seed, data_idx, key="protein_feature"):
features["crop_and_fix_size_seed"] = np.random.randint(0, 63355)
features = utils.filter(features, desired_keys=feature_names)
features = {k: torch.tensor(v) for k, v in features.items()}
with torch.no_grad():
features = process_features(features, cfg.common, cfg[mode])
if labels is not None:
labels = [{k: torch.tensor(v) for k, v in l.items()} for l in labels]
with torch.no_grad():
labels = process_labels(labels)
return features, labels
def load_and_process(
config: mlc.ConfigDict,
mode: str,
seed: int = 0,
batch_idx: Optional[int] = None,
data_idx: Optional[int] = None,
is_distillation: bool = False,
**load_kwargs,
):
is_monomer = (
is_distillation
if "is_monomer" not in load_kwargs
else load_kwargs.pop("is_monomer")
)
features, labels = load(**load_kwargs, is_monomer=is_monomer)
features, labels = process(
config, mode, features, labels, seed, batch_idx, data_idx, is_distillation
)
return features, labels
class UnifoldDataset(UnicoreDataset):
def __init__(
self,
args,
seed,
config,
data_path,
mode="train",
max_step=None,
disable_sd=False,
json_prefix="",
):
self.path = data_path
def load_json(filename):
return json.load(open(filename, "r"))
sample_weight = load_json(
os.path.join(self.path, json_prefix + mode + "_sample_weight.json")
)
self.multi_label = load_json(
os.path.join(self.path, json_prefix + mode + "_multi_label.json")
)
self.inverse_multi_label = self._inverse_map(self.multi_label)
self.sample_weight = {}
for chain in self.inverse_multi_label:
entity = self.inverse_multi_label[chain]
self.sample_weight[chain] = sample_weight[entity]
self.seq_sample_weight = sample_weight
logger.info(
"load {} chains (unique {} sequences)".format(
len(self.sample_weight), len(self.seq_sample_weight)
)
)
self.feature_path = os.path.join(self.path, "pdb_features")
self.label_path = os.path.join(self.path, "pdb_labels")
sd_sample_weight_path = os.path.join(
self.path, json_prefix + "sd_train_sample_weight.json"
)
if mode == "train" and os.path.isfile(sd_sample_weight_path) and not disable_sd:
self.sd_sample_weight = load_json(sd_sample_weight_path)
logger.info(
"load {} self-distillation samples.".format(len(self.sd_sample_weight))
)
self.sd_feature_path = os.path.join(self.path, "sd_features")
self.sd_label_path = os.path.join(self.path, "sd_labels")
else:
self.sd_sample_weight = None
self.batch_size = (
args.batch_size
* distributed_utils.get_data_parallel_world_size()
* args.update_freq[0]
)
self.data_len = (
max_step * self.batch_size
if max_step is not None
else len(self.sample_weight)
)
self.mode = mode
self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight(
self.seq_sample_weight
)
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
self.sample_weight
)
if self.sd_sample_weight is not None:
(
self.sd_num_chain,
self.sd_chain_keys,
self.sd_sample_prob,
) = self.cal_sample_weight(self.sd_sample_weight)
self.config = config.data
self.seed = seed
self.sd_prob = args.sd_prob
def cal_sample_weight(self, sample_weight):
prot_keys = list(sample_weight.keys())
sum_weight = sum(sample_weight.values())
sample_prob = [sample_weight[k] / sum_weight for k in prot_keys]
num_prot = len(prot_keys)
return num_prot, prot_keys, sample_prob
def sample_chain(self, idx, sample_by_seq=False):
is_distillation = False
if self.mode == "train":
with data_utils.numpy_seed(self.seed, idx, key="data_sample"):
is_distillation = (
(np.random.rand(1)[0] < self.sd_prob)
if self.sd_sample_weight is not None
else False
)
if is_distillation:
prot_idx = np.random.choice(
self.sd_num_chain, p=self.sd_sample_prob
)
label_name = self.sd_chain_keys[prot_idx]
seq_name = label_name
else:
if not sample_by_seq:
prot_idx = np.random.choice(self.num_chain, p=self.sample_prob)
label_name = self.chain_keys[prot_idx]
seq_name = self.inverse_multi_label[label_name]
else:
seq_idx = np.random.choice(self.num_seq, p=self.seq_sample_prob)
seq_name = self.seq_keys[seq_idx]
label_name = np.random.choice(self.multi_label[seq_name])
else:
label_name = self.chain_keys[idx]
seq_name = self.inverse_multi_label[label_name]
return seq_name, label_name, is_distillation
def __getitem__(self, idx):
sequence_id, label_id, is_distillation = self.sample_chain(
idx, sample_by_seq=True
)
feature_dir, label_dir = (
(self.feature_path, self.label_path)
if not is_distillation
else (self.sd_feature_path, self.sd_label_path)
)
features, _ = load_and_process(
self.config,
self.mode,
self.seed,
batch_idx=(idx // self.batch_size),
data_idx=idx,
is_distillation=is_distillation,
sequence_ids=[sequence_id],
monomer_feature_dir=feature_dir,
uniprot_msa_dir=None,
label_ids=[label_id],
label_dir=label_dir,
symmetry_operations=None,
is_monomer=True,
)
return features
def __len__(self):
return self.data_len
@staticmethod
def collater(samples):
# The first dim is recycling; batch size is at the 2nd dim.
return data_utils.collate_dict(samples, dim=1)
@staticmethod
def _inverse_map(mapping: Dict[str, List[str]]):
inverse_mapping = {}
for ent, refs in mapping.items():
for ref in refs:
if ref in inverse_mapping: # duplicated ent for this ref.
ent_2 = inverse_mapping[ref]
assert (
ent == ent_2
), f"multiple entities ({ent_2}, {ent}) exist for reference {ref}."
inverse_mapping[ref] = ent
return inverse_mapping
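# Illustrative behaviour (a sketch, not part of the original class): a
# multi-label map {"seq1": ["1abc_A", "1abc_B"]} is inverted to
# {"1abc_A": "seq1", "1abc_B": "seq1"}; conflicting entities for the same
# chain raise an AssertionError.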
class UnifoldMultimerDataset(UnifoldDataset):
def __init__(
self,
args: mlc.ConfigDict,
seed: int,
config: mlc.ConfigDict,
data_path: str,
mode: str = "train",
max_step: Optional[int] = None,
disable_sd: bool = False,
json_prefix: str = "",
**kwargs,
):
super().__init__(
args, seed, config, data_path, mode, max_step, disable_sd, json_prefix
)
self.data_path = data_path
self.pdb_assembly = json.load(
open(os.path.join(self.data_path, json_prefix + "pdb_assembly.json"))
)
self.pdb_chains = self.get_chains(self.inverse_multi_label)
self.monomer_feature_path = os.path.join(self.data_path, "pdb_features")
self.uniprot_msa_path = os.path.join(self.data_path, "pdb_uniprots")
self.label_path = os.path.join(self.data_path, "pdb_labels")
self.max_chains = args.max_chains
if self.mode == "train":
self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains(
self.pdb_chains, self.pdb_assembly, self.sample_weight, self.max_chains
)
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
self.sample_weight
)
def __getitem__(self, idx):
seq_id, label_id, is_distillation = self.sample_chain(idx)
if is_distillation:
label_ids = [label_id]
sequence_ids = [seq_id]
monomer_feature_path, uniprot_msa_path, label_path = (
self.sd_feature_path,
None,
self.sd_label_path,
)
symmetry_operations = None
else:
pdb_id = self.get_pdb_name(label_id)
if pdb_id in self.pdb_assembly and self.mode == "train":
label_ids = [
pdb_id + "_" + id for id in self.pdb_assembly[pdb_id]["chains"]
]
symmetry_operations = [t for t in self.pdb_assembly[pdb_id]["opers"]]
else:
label_ids = self.pdb_chains[pdb_id]
symmetry_operations = None
sequence_ids = [
self.inverse_multi_label[chain_id] for chain_id in label_ids
]
monomer_feature_path, uniprot_msa_path, label_path = (
self.monomer_feature_path,
self.uniprot_msa_path,
self.label_path,
)
return load_and_process(
self.config,
self.mode,
self.seed,
batch_idx=(idx // self.batch_size),
data_idx=idx,
is_distillation=is_distillation,
sequence_ids=sequence_ids,
monomer_feature_dir=monomer_feature_path,
uniprot_msa_dir=uniprot_msa_path,
label_ids=label_ids,
label_dir=label_path,
symmetry_operations=symmetry_operations,
is_monomer=False,
)
@staticmethod
def collater(samples):
# The first dim is recycling; batch size is at the 2nd dim.
if len(samples) <= 0:  # handle an empty batch
return None
feats = [s[0] for s in samples]
labs = [s[1] for s in samples if s[1] is not None]
try:
feats = data_utils.collate_dict(feats, dim=1)
except Exception as e:
raise ValueError("cannot collate features", feats) from e
if not labs:
labs = None
return feats, labs
@staticmethod
def get_pdb_name(chain):
return chain.split("_")[0]
@staticmethod
def get_chains(canon_chain_map):
pdb_chains = {}
for chain in canon_chain_map:
pdb = UnifoldMultimerDataset.get_pdb_name(chain)
if pdb not in pdb_chains:
pdb_chains[pdb] = []
pdb_chains[pdb].append(chain)
return pdb_chains
@staticmethod
def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, max_chains):
new_pdb_chains = {}
for chain in pdb_chains:
if chain in pdb_assembly:
size = len(pdb_assembly[chain]["chains"])
if size <= max_chains:
new_pdb_chains[chain] = pdb_chains[chain]
else:
size = len(pdb_chains[chain])
if size == 1:
new_pdb_chains[chain] = pdb_chains[chain]
new_sample_weight = {
k: sample_weight[k]
for k in sample_weight
if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains
}
logger.info(
f"filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs "
f"({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) "
f"by max_chains {max_chains}"
)
return new_pdb_chains, new_sample_weight
# Copyright 2022 DP Technology
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run CPU MSA & template searching to get pickled features."""
import json
import os
import pickle
from pathlib import Path
import shutil
import time
import gzip
from absl import app
from absl import flags
from absl import logging
from unifold.data.utils import compress_features
from unifold.msa import parsers
from unifold.msa import pipeline
from unifold.msa import templates
from unifold.msa.utils import divide_multi_chains
from unifold.msa.tools import hmmsearch
logging.set_verbosity(logging.INFO)
flags.DEFINE_string(
"fasta_path",
None,
"Path to FASTA file, If a FASTA file contains multiple sequences, "
"then it will be divided into several single sequences. ",
)
flags.DEFINE_string(
"output_dir", None, "Path to a directory that will " "store the results."
)
flags.DEFINE_string(
"jackhmmer_binary_path",
shutil.which("jackhmmer"),
"Path to the JackHMMER executable.",
)
flags.DEFINE_string(
"hhblits_binary_path", shutil.which("hhblits"), "Path to the HHblits executable."
)
flags.DEFINE_string(
"hhsearch_binary_path", shutil.which("hhsearch"), "Path to the HHsearch executable."
)
flags.DEFINE_string(
"hmmsearch_binary_path",
shutil.which("hmmsearch"),
"Path to the hmmsearch executable.",
)
flags.DEFINE_string(
"hmmbuild_binary_path", shutil.which("hmmbuild"), "Path to the hmmbuild executable."
)
flags.DEFINE_string(
"kalign_binary_path", shutil.which("kalign"), "Path to the Kalign executable."
)
flags.DEFINE_string(
"uniref90_database_path",
None,
"Path to the Uniref90 database for use by JackHMMER.",
)
flags.DEFINE_string(
"mgnify_database_path", None, "Path to the MGnify database for use by JackHMMER."
)
flags.DEFINE_string(
"bfd_database_path", None, "Path to the BFD database for use by HHblits."
)
flags.DEFINE_string(
"small_bfd_database_path",
None,
'Path to the small version of BFD used with the "reduced_dbs" preset.',
)
flags.DEFINE_string(
"uniclust30_database_path",
None,
"Path to the Uniclust30 " "database for use by HHblits.",
)
flags.DEFINE_string(
"uniprot_database_path",
None,
"Path to the Uniprot database for use by JackHMMer.",
)
flags.DEFINE_string(
"pdb_seqres_database_path",
None,
"Path to the PDB seqres database for use by hmmsearch.",
)
flags.DEFINE_string(
"template_mmcif_dir",
None,
"Path to a directory with template mmCIF structures, each named " "<pdb_id>.cif",
)
flags.DEFINE_string(
"max_template_date",
None,
"Maximum template release date to consider. Important if folding "
"historical test sets.",
)
flags.DEFINE_string(
"obsolete_pdbs_path",
None,
"Path to file containing a mapping from obsolete PDB IDs to the PDB IDs "
"of their replacements.",
)
flags.DEFINE_enum(
"db_preset",
"full_dbs",
["full_dbs", "reduced_dbs"],
"Choose preset MSA database configuration - smaller genetic database "
"config (reduced_dbs) or full genetic database config (full_dbs)",
)
flags.DEFINE_boolean(
"use_precomputed_msas",
True,
"Whether to read MSAs that have been written to disk instead of running "
"the MSA tools. The MSA files are looked up in the output directory, "
"so it must stay the same between multiple runs that are to reuse the "
"MSAs. WARNING: This will not check if the sequence, database or "
"configuration have changed.",
)
flags.DEFINE_boolean("use_uniprot", True, "Whether to use UniProt MSAs.")
FLAGS = flags.FLAGS
MAX_TEMPLATE_HITS = 20
def _check_flag(flag_name: str, other_flag_name: str, should_be_set: bool):
if should_be_set != bool(FLAGS[flag_name].value):
verb = "be" if should_be_set else "not be"
raise ValueError(
f"{flag_name} must {verb} set when running with "
f'"--{other_flag_name}={FLAGS[other_flag_name].value}".'
)
def generate_pkl_features(
fasta_path: str,
fasta_name: str,
output_dir_base: str,
data_pipeline: pipeline.DataPipeline,
use_uniprot: bool,
):
"""
Predicts structure using AlphaFold for the given sequence.
"""
logging.info(f"searching homogeneous Sequences & structures for {fasta_name}...")
timings = {}
output_dir = os.path.join(output_dir_base, fasta_name.split("_")[0])
if not os.path.exists(output_dir):
os.makedirs(output_dir)
chain_id = fasta_name.split("_")[1] if len(fasta_name.split("_")) > 1 else "A"
msa_output_dir = os.path.join(output_dir, chain_id)
if not os.path.exists(msa_output_dir):
os.makedirs(msa_output_dir)
# Get features.
features_output_path = os.path.join(
output_dir, "{}.feature.pkl.gz".format(chain_id)
)
if not os.path.exists(features_output_path):
t_0 = time.time()
feature_dict = data_pipeline.process(
input_fasta_path=fasta_path, msa_output_dir=msa_output_dir
)
timings["features"] = time.time() - t_0
feature_dict = compress_features(feature_dict)
pickle.dump(feature_dict, gzip.GzipFile(features_output_path, "wb"), protocol=4)
# Get uniprot
if use_uniprot:
uniprot_output_path = os.path.join(
output_dir, "{}.uniprot.pkl.gz".format(chain_id)
)
if not os.path.exists(uniprot_output_path):
t_0 = time.time()
all_seq_feature_dict = data_pipeline.process_uniprot(
input_fasta_path=fasta_path, msa_output_dir=msa_output_dir
)
timings["all_seq_features"] = time.time() - t_0
all_seq_feature_dict = compress_features(all_seq_feature_dict)
pickle.dump(
all_seq_feature_dict,
gzip.GzipFile(uniprot_output_path, "wb"),
protocol=4,
)
logging.info("Final timings for %s: %s", fasta_name, timings)
timings_output_path = os.path.join(output_dir, "{}.timings.json".format(chain_id))
with open(timings_output_path, "w") as f:
f.write(json.dumps(timings, indent=4))
def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
for tool_name in (
"jackhmmer",
"hhblits",
"hhsearch",
"hmmsearch",
"hmmbuild",
"kalign",
):
if not FLAGS[f"{tool_name}_binary_path"].value:
raise ValueError(
f'Could not find path to the "{tool_name}" binary. Make '
"sure it is installed on your system."
)
use_small_bfd = FLAGS.db_preset == "reduced_dbs"
_check_flag("small_bfd_database_path", "db_preset", should_be_set=use_small_bfd)
_check_flag("bfd_database_path", "db_preset", should_be_set=not use_small_bfd)
_check_flag(
"uniclust30_database_path", "db_preset", should_be_set=not use_small_bfd
)
template_searcher = hmmsearch.Hmmsearch(
binary_path=FLAGS.hmmsearch_binary_path,
hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
database_path=FLAGS.pdb_seqres_database_path,
)
template_featurizer = templates.HmmsearchHitFeaturizer(
mmcif_dir=FLAGS.template_mmcif_dir,
max_template_date=FLAGS.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
kalign_binary_path=FLAGS.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=FLAGS.obsolete_pdbs_path,
)
data_pipeline = pipeline.DataPipeline(
jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
hhblits_binary_path=FLAGS.hhblits_binary_path,
uniref90_database_path=FLAGS.uniref90_database_path,
mgnify_database_path=FLAGS.mgnify_database_path,
bfd_database_path=FLAGS.bfd_database_path,
uniclust30_database_path=FLAGS.uniclust30_database_path,
small_bfd_database_path=FLAGS.small_bfd_database_path,
uniprot_database_path=FLAGS.uniprot_database_path,
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
use_precomputed_msas=FLAGS.use_precomputed_msas,
)
fasta_path = FLAGS.fasta_path
fasta_name = Path(fasta_path).stem
input_fasta_str = open(fasta_path).read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) > 1:
temp_names, temp_paths = divide_multi_chains(
fasta_name, FLAGS.output_dir, input_seqs, input_descs
)
fasta_names = temp_names
fasta_paths = temp_paths
else:
fasta_names = [fasta_name]
fasta_paths = [fasta_path]
# Check for duplicate FASTA file names.
if len(fasta_names) != len(set(fasta_names)):
raise ValueError("All FASTA paths must have a unique basename.")
# Generate pickled features for each of the sequences.
for i, fasta_path in enumerate(fasta_paths):
fasta_name = fasta_names[i]
generate_pkl_features(
fasta_path=fasta_path,
fasta_name=fasta_name,
output_dir_base=FLAGS.output_dir,
data_pipeline=data_pipeline,
use_uniprot=FLAGS.use_uniprot,
)
if __name__ == "__main__":
flags.mark_flags_as_required(
[
"fasta_path",
"output_dir",
"uniref90_database_path",
"mgnify_database_path",
"template_mmcif_dir",
"max_template_date",
"obsolete_pdbs_path",
]
)
app.run(main)
import argparse
import gzip
import logging
import math
import numpy as np
import os
import time
import torch
import json
import pickle
from unifold.config import model_config
from unifold.modules.alphafold import AlphaFold
from unifold.data import residue_constants, protein
from unifold.dataset import load_and_process, UnifoldDataset
from unicore.utils import (
tensor_tree_map,
)
def get_device_mem(device):
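    """Return the total memory (in GB) of the current CUDA device, or a
    default of 40 GB when running on CPU or when CUDA is unavailable."""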
if device != "cpu" and torch.cuda.is_available():
cur_device = torch.cuda.current_device()
prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device))
total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024
return total_memory_in_GB
else:
return 40
def automatic_chunk_size(seq_len, device, is_bf16):
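    """Pick attention chunk/block sizes from sequence length and GPU memory.

    The length thresholds are scaled by a factor derived from available memory
    (normalized to a 40 GB card) and the precision in use. For example, on a
    40 GB GPU with bf16 the factor is ~0.95, so a ~900-residue target gets
    chunk_size=256 with no block sparsity, while a ~3000-residue target gets
    chunk_size=32 and block_size=512.
    """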
total_mem_in_GB = get_device_mem(device)
    factor = math.sqrt(total_mem_in_GB / 40.0 * (0.55 * is_bf16 + 0.45)) * 0.95
if seq_len < int(1024*factor):
chunk_size = 256
block_size = None
elif seq_len < int(2048*factor):
chunk_size = 128
block_size = None
elif seq_len < int(3072*factor):
chunk_size = 64
block_size = None
elif seq_len < int(4096*factor):
chunk_size = 32
block_size = 512
else:
chunk_size = 4
block_size = 256
return chunk_size, block_size
def load_feature_for_one_target(
config, data_folder, seed=0, is_multimer=False, use_uniprot=False
):
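    """Load and process the features of a single target.

    For monomers, only chain "A" is used (with UniProt MSAs only if
    `use_uniprot` is set); for multimers, chain ids are read from `chains.txt`
    in the data folder and UniProt MSAs are always used for pairing.
    Returns a collated batch ready for inference."""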
if not is_multimer:
uniprot_msa_dir = None
sequence_ids = ["A"]
if use_uniprot:
uniprot_msa_dir = data_folder
else:
uniprot_msa_dir = data_folder
        with open(os.path.join(data_folder, "chains.txt")) as f:
            sequence_ids = f.readline().split()
batch, _ = load_and_process(
config=config.data,
mode="predict",
seed=seed,
batch_idx=None,
data_idx=0,
is_distillation=False,
sequence_ids=sequence_ids,
monomer_feature_dir=data_folder,
uniprot_msa_dir=uniprot_msa_dir,
is_monomer=(not is_multimer),
)
batch = UnifoldDataset.collater([batch])
return batch
def main(args):
config = model_config(args.model_name)
config.data.common.max_recycling_iters = args.max_recycling_iters
config.globals.max_recycling_iters = args.max_recycling_iters
config.data.predict.num_ensembles = args.num_ensembles
is_multimer = config.model.is_multimer
if args.sample_templates:
# enable template samples for diversity
config.data.predict.subsample_templates = True
model = AlphaFold(config)
print("start to load params {}".format(args.param_path))
state_dict = torch.load(args.param_path)["ema"]["params"]
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.to(args.model_device)
model.eval()
model.inference_mode()
if args.bf16:
model.bfloat16()
# data path is based on target_name
data_dir = os.path.join(args.data_dir, args.target_name)
output_dir = os.path.join(args.output_dir, args.target_name)
    os.makedirs(output_dir, exist_ok=True)
cur_param_path_postfix = os.path.split(args.param_path)[-1]
name_postfix = ""
if args.sample_templates:
name_postfix += "_st"
if not is_multimer and args.use_uniprot:
name_postfix += "_uni"
if args.max_recycling_iters != 3:
name_postfix += "_r" + str(args.max_recycling_iters)
if args.num_ensembles != 2:
name_postfix += "_e" + str(args.num_ensembles)
print("start to predict {}".format(args.target_name))
plddts = {}
ptms = {}
for seed in range(args.times):
cur_seed = hash((args.data_random_seed, seed)) % 100000
batch = load_feature_for_one_target(
config,
data_dir,
cur_seed,
is_multimer=is_multimer,
use_uniprot=args.use_uniprot,
)
seq_len = batch["aatype"].shape[-1]
# faster prediction with large chunk/block size
chunk_size, block_size = automatic_chunk_size(
seq_len,
args.model_device,
args.bf16
)
model.globals.chunk_size = chunk_size
model.globals.block_size = block_size
with torch.no_grad():
batch = {
k: torch.as_tensor(v, device=args.model_device)
for k, v in batch.items()
}
shapes = {k: v.shape for k, v in batch.items()}
print(shapes)
t = time.perf_counter()
raw_out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
def to_float(x):
if x.dtype == torch.bfloat16 or x.dtype == torch.half:
return x.float()
else:
return x
if not args.save_raw_output:
score = ["plddt", "ptm", "iptm", "iptm+ptm"]
out = {
k: v for k, v in raw_out.items()
if k.startswith("final_") or k in score
}
else:
out = raw_out
del raw_out
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
batch = tensor_tree_map(to_float, batch)
out = tensor_tree_map(lambda t: t[0, ...], out)
out = tensor_tree_map(to_float, out)
batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
            # TODO: may need to reorder chains based on entity_ids
cur_protein = protein.from_prediction(
features=batch, result=out, b_factors=plddt_b_factors
)
cur_save_name = (
f"{args.model_name}_{cur_param_path_postfix}_{cur_seed}{name_postfix}"
)
plddts[cur_save_name] = str(mean_plddt)
if is_multimer:
ptms[cur_save_name] = str(np.mean(out["iptm+ptm"]))
with open(os.path.join(output_dir, cur_save_name + '.pdb'), "w") as f:
f.write(protein.to_pdb(cur_protein))
if args.save_raw_output:
with gzip.open(os.path.join(output_dir, cur_save_name + '_outputs.pkl.gz'), 'wb') as f:
pickle.dump(out, f)
del out
print("plddts", plddts)
score_name = f"{args.model_name}_{cur_param_path_postfix}_{args.data_random_seed}_{args.times}{name_postfix}"
plddt_fname = score_name + "_plddt.json"
    with open(os.path.join(output_dir, plddt_fname), "w") as f:
        json.dump(plddts, f, indent=4)
if ptms:
print("ptms", ptms)
ptm_fname = score_name + "_ptm.json"
        with open(os.path.join(output_dir, ptm_fname), "w") as f:
            json.dump(ptms, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_device",
type=str,
default="cuda:0",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")""",
)
parser.add_argument(
"--model_name",
type=str,
default="model_2",
)
parser.add_argument(
"--param_path", type=str, default=None, help="Path to model parameters."
)
parser.add_argument(
"--data_random_seed",
type=int,
default=42,
)
parser.add_argument(
"--data_dir",
type=str,
default="",
)
parser.add_argument(
"--target_name",
type=str,
default="",
)
parser.add_argument(
"--output_dir",
type=str,
default="",
)
parser.add_argument(
"--times",
type=int,
default=3,
)
parser.add_argument(
"--max_recycling_iters",
type=int,
default=3,
)
parser.add_argument(
"--num_ensembles",
type=int,
default=2,
)
parser.add_argument("--sample_templates", action="store_true")
parser.add_argument("--use_uniprot", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--save_raw_output", action="store_true")
args = parser.parse_args()
if args.model_device == "cpu" and torch.cuda.is_available():
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main(args)
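
# UF-Symmetry inference entry point: same flow as the inference script above,
# but uses the UFSymmetry model together with a symmetry-group argument, and
# expands the predicted asymmetric unit into the full assembly before writing
# the PDB file.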
import argparse
import gzip
import logging
import numpy as np
import os
import pathlib
import time
import torch
import pickle
from unifold.data import residue_constants, protein
from unifold.dataset import UnifoldDataset
from unicore.utils import (
tensor_tree_map,
)
from unifold.symmetry import (
UFSymmetry,
load_and_process_symmetry,
uf_symmetry_config,
assembly_from_prediction,
)
from unifold.inference import (
automatic_chunk_size,
)
def load_feature_for_one_target(
config, data_folder, symmetry, seed=0, is_multimer=False, use_uniprot=False
):
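    """Same as the non-symmetric loader, but threads the symmetry group through
    load_and_process_symmetry so the symmetry operators are part of the batch."""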
if not is_multimer:
uniprot_msa_dir = None
sequence_ids = ["A"]
if use_uniprot:
uniprot_msa_dir = data_folder
else:
uniprot_msa_dir = data_folder
        with open(os.path.join(data_folder, "chains.txt")) as f:
            sequence_ids = f.readline().split()
batch, _ = load_and_process_symmetry(
config=config.data,
mode="predict",
seed=seed,
batch_idx=None,
data_idx=0,
is_distillation=False,
symmetry=symmetry,
sequence_ids=sequence_ids,
monomer_feature_dir=data_folder,
uniprot_msa_dir=uniprot_msa_dir,
is_monomer=(not is_multimer),
)
batch = UnifoldDataset.collater([batch])
return batch
def main(args):
config = uf_symmetry_config()
config.data.common.max_recycling_iters = args.max_recycling_iters
config.globals.max_recycling_iters = args.max_recycling_iters
config.data.predict.num_ensembles = args.num_ensembles
is_multimer = config.model.is_multimer
if args.sample_templates:
# enable template samples for diversity
config.data.predict.subsample_templates = True
# faster prediction with large chunk
config.globals.chunk_size = 128
model = UFSymmetry(config)
print("start to load params {}".format(args.param_path))
state_dict = torch.load(args.param_path)["ema"]["params"]
state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model = model.to(args.model_device)
model.eval()
model.inference_mode()
if args.bf16:
model.bfloat16()
# data path is based on target_name
data_dir = os.path.join(args.data_dir, args.target_name)
output_dir = os.path.join(args.output_dir, args.target_name)
    os.makedirs(output_dir, exist_ok=True)
param_name = pathlib.Path(args.param_path).stem
name_suffix = ""
if args.sample_templates:
name_suffix += "_st"
if not is_multimer and args.use_uniprot:
name_suffix += "_uni"
if args.max_recycling_iters != 3:
name_suffix += "_r" + str(args.max_recycling_iters)
if args.num_ensembles != 2:
name_suffix += "_e" + str(args.num_ensembles)
symmetry = args.symmetry
    if symmetry[0] != 'C':
        raise NotImplementedError(f"symmetry {symmetry} is not supported yet; only cyclic (C) groups are handled.")
print("start to predict {}".format(args.target_name))
for seed in range(args.times):
cur_seed = hash((args.data_random_seed, seed)) % 100000
batch = load_feature_for_one_target(
config,
data_dir,
args.symmetry,
cur_seed,
is_multimer=is_multimer,
use_uniprot=args.use_uniprot,
)
seq_len = batch["aatype"].shape[-1]
# faster prediction with large chunk/block size
chunk_size, block_size = automatic_chunk_size(
seq_len,
args.model_device,
args.bf16
)
model.globals.chunk_size = chunk_size
model.globals.block_size = block_size
with torch.no_grad():
batch = {
k: torch.as_tensor(v, device=args.model_device)
for k, v in batch.items()
}
shapes = {k: v.shape for k, v in batch.items()}
print(shapes)
t = time.perf_counter()
raw_out = model(batch, expand=True) # when expand, output assembly.
print(f"Inference time: {time.perf_counter() - t}")
def to_float(x):
if x.dtype == torch.bfloat16 or x.dtype == torch.half:
return x.float()
else:
return x
out = raw_out
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
batch = tensor_tree_map(to_float, batch)
out = tensor_tree_map(lambda t: t[0, ...], out)
out = tensor_tree_map(to_float, out)
batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1
)
plddt_b_factors_assembly = np.repeat(
plddt_b_factors, batch["symmetry_opers"].shape[0], axis=-2)
cur_assembly = assembly_from_prediction(
result=out, b_factors=plddt_b_factors_assembly
)
cur_save_name = (
f"ufsymm_{param_name}_{cur_seed}{name_suffix}"
)
with open(os.path.join(output_dir, cur_save_name + '.pdb'), "w") as f:
f.write(protein.to_pdb(cur_assembly))
if args.save_raw_output:
out = {
k: v for k, v in out.items()
if k.startswith("final_") or k.startswith("expand_final_") or k == "plddt"
}
with gzip.open(os.path.join(output_dir, cur_save_name + '_outputs.pkl.gz'), 'wb') as f:
pickle.dump(out, f)
del out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_device",
type=str,
default="cuda:0",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")""",
)
    # The model name is fixed to `uf_symmetry`, so no --model_name flag is needed.
parser.add_argument(
"--param_path", type=str, default=None, help="Path to model parameters."
)
parser.add_argument(
"--data_random_seed",
type=int,
default=42,
)
parser.add_argument(
"--data_dir",
type=str,
default="",
)
parser.add_argument(
"--target_name",
type=str,
default="",
)
parser.add_argument(
"--symmetry",
type=str,
default="C1",
)
parser.add_argument(
"--output_dir",
type=str,
default="",
)
parser.add_argument(
"--times",
type=int,
default=3,
)
parser.add_argument(
"--max_recycling_iters",
type=int,
default=3,
)
parser.add_argument(
"--num_ensembles",
type=int,
default=2,
)
parser.add_argument("--sample_templates", action="store_true")
parser.add_argument("--use_uniprot", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--save_raw_output", action="store_true")
args = parser.parse_args()
if args.model_device == "cpu" and torch.cuda.is_available():
logging.warning(
"""The model is being run on CPU. Consider specifying
--model_device for better performance"""
)
main(args)
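
# Training losses: AlphaFold-style loss terms registered with unicore, plus a
# multimer variant that permutation-aligns ground-truth chains to predicted
# chains before computing the loss.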
import logging
import torch
from unicore import metrics
from unicore.utils import tensor_tree_map
from unicore.losses import UnicoreLoss, register_loss
from unicore.data import data_utils
from unifold.losses.geometry import compute_renamed_ground_truth, compute_metric
from unifold.losses.violation import find_structural_violations, violation_loss
from unifold.losses.fape import fape_loss
from unifold.losses.auxillary import (
chain_centre_mass_loss,
distogram_loss,
experimentally_resolved_loss,
masked_msa_loss,
pae_loss,
plddt_loss,
repr_norm_loss,
supervised_chi_loss,
)
from unifold.losses.chain_align import multi_chain_perm_align
@register_loss("af2")
class AlphafoldLoss(UnicoreLoss):
def __init__(self, task):
super().__init__(task)
def forward(self, model, batch, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
# return config in model.
out, config = model(batch)
num_recycling = batch["msa_feat"].shape[0]
        # remove the recycling dimension
batch = tensor_tree_map(lambda t: t[-1, ...], batch)
loss, sample_size, logging_output = self.loss(out, batch, config)
logging_output["num_recycling"] = num_recycling
return loss, sample_size, logging_output
def loss(self, out, batch, config):
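        """Accumulate the weighted sum of all enabled loss terms.

        Structural violations and renamed ground-truth atoms are computed
        lazily if not already present in the model output. Each term with a
        positive weight in the config is evaluated, checked for NaN/Inf, and
        added to the cumulative loss, which is scaled by sqrt(seq_len) per
        sample before averaging over the batch."""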
if "violation" not in out.keys() and config.violation.weight:
out["violation"] = find_structural_violations(
batch, out["sm"]["positions"], **config.violation)
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(batch, out["sm"]["positions"]))
loss_dict = {}
loss_fns = {
"chain_centre_mass": lambda: chain_centre_mass_loss(
pred_atom_positions=out["final_atom_positions"],
true_atom_positions=batch["all_atom_positions"],
atom_mask=batch["all_atom_mask"],
asym_id=batch["asym_id"],
**config.chain_centre_mass,
loss_dict=loss_dict,
),
"distogram": lambda: distogram_loss(
logits=out["distogram_logits"],
pseudo_beta=batch["pseudo_beta"],
pseudo_beta_mask=batch["pseudo_beta_mask"],
**config.distogram,
loss_dict=loss_dict,
),
"experimentally_resolved": lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
atom37_atom_exists=batch["atom37_atom_exists"],
all_atom_mask=batch["all_atom_mask"],
resolution=batch["resolution"],
**config.experimentally_resolved,
loss_dict=loss_dict,
),
"fape": lambda: fape_loss(
out,
batch,
config.fape,
loss_dict=loss_dict,
),
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
true_msa=batch["true_msa"],
bert_mask=batch["bert_mask"],
loss_dict=loss_dict,
),
"pae": lambda: pae_loss(
logits=out["pae_logits"],
pred_frame_tensor=out["pred_frame_tensor"],
true_frame_tensor=batch["true_frame_tensor"],
frame_mask=batch["frame_mask"],
resolution=batch["resolution"],
**config.pae,
loss_dict=loss_dict,
),
"plddt": lambda: plddt_loss(
logits=out["plddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
all_atom_positions=batch["all_atom_positions"],
all_atom_mask=batch["all_atom_mask"],
resolution=batch["resolution"],
**config.plddt,
loss_dict=loss_dict,
),
"repr_norm": lambda: repr_norm_loss(
out["delta_msa"],
out["delta_pair"],
out["msa_norm_mask"],
batch["pseudo_beta_mask"],
**config.repr_norm,
loss_dict=loss_dict,
),
"supervised_chi": lambda: supervised_chi_loss(
pred_angles_sin_cos=out["sm"]["angles"],
pred_unnormed_angles_sin_cos=out["sm"]["unnormalized_angles"],
true_angles_sin_cos=batch["chi_angles_sin_cos"],
aatype=batch["aatype"],
seq_mask=batch["seq_mask"],
chi_mask=batch["chi_mask"],
**config.supervised_chi,
loss_dict=loss_dict,
),
"violation": lambda: violation_loss(
out["violation"],
loss_dict=loss_dict,
bond_angle_loss_weight=config.violation.bond_angle_loss_weight,
),
}
cum_loss = 0
bsz = batch["seq_mask"].shape[0]
with torch.no_grad():
seq_len = torch.sum(batch["seq_mask"].float(), dim=-1)
seq_length_weight = seq_len**0.5
assert (
len(seq_length_weight.shape) == 1 and seq_length_weight.shape[0] == bsz
), seq_length_weight.shape
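        # Accumulate each enabled loss term; a term that evaluates to NaN/Inf is
        # replaced by a zero tensor (with grad) so one bad term does not poison the update.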
for loss_name, loss_fn in loss_fns.items():
weight = config[loss_name].weight
if weight > 0.:
loss = loss_fn()
# always use float type for loss
assert loss.dtype == torch.float, loss.dtype
assert len(loss.shape) == 1 and loss.shape[0] == bsz, loss.shape
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    logging.warning(f"{loss_name} loss is NaN or Inf. Skipping...")
loss = loss.new_tensor(0.0, requires_grad=True)
cum_loss = cum_loss + weight * loss
for key in loss_dict:
loss_dict[key] = float((loss_dict[key]).mean())
loss = (cum_loss * seq_length_weight).mean()
logging_output = loss_dict
        # sample size is fixed to 1, so the loss (and gradients) are averaged across all workers.
sample_size = 1
logging_output["loss"] = loss.data
logging_output["bsz"] = bsz
logging_output["sample_size"] = sample_size
logging_output["seq_len"] = seq_len
# logging_output["num_recycling"] = num_recycling
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs, split="valid") -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=4)
for key in logging_outputs[0]:
if key in ["sample_size", "bsz"]:
continue
loss_sum = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(key, loss_sum / sample_size, sample_size, round=4)
@staticmethod
def logging_outputs_can_be_summed(is_train) -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
"""
return True
@register_loss("afm")
class AlphafoldMultimerLoss(AlphafoldLoss):
def forward(self, model, batch, reduce=True):
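        """Multimer forward pass: run the model, then (if labels are provided)
        permute the ground-truth chains to best match the predicted chains via
        multi_chain_perm_align before delegating to the shared loss."""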
features, labels = batch
assert isinstance(features, dict)
# return config in model.
out, config = model(features)
num_recycling = features["msa_feat"].shape[0]
# remove recycling dim
features = tensor_tree_map(lambda t: t[-1, ...], features)
# perform multi-chain permutation alignment.
if labels:
with torch.no_grad():
batch_size = out["final_atom_positions"].shape[0]
new_labels = []
for batch_idx in range(batch_size):
cur_out = {
k: out[k][batch_idx]
for k in out
if k in {"final_atom_positions", "final_atom_mask"}
}
cur_feature = {k: features[k][batch_idx] for k in features}
cur_label = labels[batch_idx]
cur_new_labels = multi_chain_perm_align(
cur_out, cur_feature, cur_label
)
new_labels.append(cur_new_labels)
new_labels = data_utils.collate_dict(new_labels, dim=0)
# check for consistency of label and feature.
assert (new_labels["aatype"] == features["aatype"]).all()
features.update(new_labels)
loss, sample_size, logging_output = self.loss(out, features, config)
logging_output["num_recycling"] = num_recycling
return loss, sample_size, logging_output