"vscode:/vscode.git/clone" did not exist on "cfb4c19cac1a4e2358edc02a6baff89a1e4377c0"
Commit 4bd1b4d5 authored by Gustaf Ahdritz

Work on multimer continues

parent 54164fe8
from . import model
from . import utils
from . import data
from . import np
from . import resources
...
@@ -75,7 +75,17 @@ def model_config(name, train=False, low_prec=False):
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name:
-c.model.update(multimer_model_config_update)
c.globals.is_multimer = True
for k, v in multimer_model_config_update.items():
c.model[k] = v
c.data.common.unsupervised_features.extend([
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
])
else:
raise ValueError("Invalid model name")
@@ -276,6 +286,7 @@ config = mlc.ConfigDict(
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"is_multimer": False,
},
"model": {
"_mask_trans": False,
@@ -335,6 +346,7 @@ config = mlc.ConfigDict(
"eps": eps,  # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": False,
},
"extra_msa": {
"extra_msa_embedder": {
@@ -496,10 +508,76 @@ config = mlc.ConfigDict(
}
)

-multimer_model_config_update = mlc.ConfigDict(
-"relative_encoding": {
-"enabled": True,
-"max_relative_chain": 2,
-"max_relative_idx": 32,
-}
-)
multimer_model_config_update = {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
},
"c_t": c_t,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": tm_enabled,
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
}
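# Illustrative arithmetic for the values above (a sketch, for reference only):
# with use_chain_relative=True, the relative-position encoding built by
# InputEmbedderMultimer (added later in this commit) spans
# (2 * max_relative_idx + 2) position bins, 1 same-entity bit, and
# (2 * max_relative_chain + 2) chain bins as inputs to linear_relpos.
assert (2 * 32 + 2) + 1 + (2 * 2 + 2) == 73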
...
@@ -428,10 +428,16 @@ def make_hhblits_profile(protein):

@curry1
-def make_masked_msa(protein, config, replace_fraction):
def make_masked_msa(protein, config, replace_fraction, seed):
"""Create data for BERT on raw MSA."""
device = protein["msa"].device

# Add a random amino acid uniformly.
-random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
random_aa = torch.tensor(
[0.05] * 20 + [0.0, 0.0],
dtype=torch.float32,
device=device
)

categorical_probs = (
config.uniform_prob * random_aa
@@ -449,11 +455,17 @@ def make_masked_msa(protein, config, replace_fraction):
)
assert mask_prob >= 0.0

categorical_probs = torch.nn.functional.pad(
-categorical_probs, pad_shapes, value=mask_prob
categorical_probs, pad_shapes, value=mask_prob,
)

sh = protein["msa"].shape
-mask_position = torch.rand(sh) < replace_fraction
g = torch.Generator(device=protein["msa"].device)
if seed is not None:
g.manual_seed(seed)
sample = torch.rand(sh, device=device, generator=g)
mask_position = sample < replace_fraction

bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
...
from typing import Sequence
import torch
from openfold.data.data_transforms import curry1
from openfold.utils.tensor_utils import masked_mean
def gumbel_noise(
shape: Sequence[int],
device: torch.device,
eps=1e-6,
generator=None,
) -> torch.Tensor:
"""Generate Gumbel Noise of given Shape.
This generates samples from Gumbel(0, 1).
Args:
shape: Shape of noise to return.
Returns:
Gumbel noise of given shape.
"""
uniform_noise = torch.rand(
shape, dtype=torch.float32, device=device, generator=generator
)
gumbel = -torch.log(-torch.log(uniform_noise + eps) + eps)
return gumbel
def gumbel_max_sample(logits: torch.Tensor, generator=None) -> torch.Tensor:
"""Samples from a probability distribution given by 'logits'.
This uses Gumbel-max trick to implement the sampling in an efficient manner.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Sample from logprobs in one-hot form.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.nn.functional.one_hot(
torch.argmax(logits + z, dim=-1),
logits.shape[-1],
)
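# A minimal, illustrative check (assumed toy distribution, not part of the
# pipeline): the Gumbel-max trick should reproduce softmax(logits) sampling
# frequencies.
def _demo_gumbel_max_sample():
    logits = torch.log(torch.tensor([0.1, 0.3, 0.6]))
    g = torch.Generator().manual_seed(0)
    # 20000 independent draws in one batched call
    samples = gumbel_max_sample(logits.expand(20000, -1), generator=g)
    # Empirical frequencies should approach (0.1, 0.3, 0.6)
    print(samples.float().mean(dim=0))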
def gumbel_argsort_sample_idx(
logits: torch.Tensor,
generator=None
) -> torch.Tensor:
"""Samples with replacement from a distribution given by 'logits'.
This uses Gumbel trick to implement the sampling an efficient manner. For a
distribution over k items this samples k times without replacement, so this
is effectively sampling a random permutation with probabilities over the
permutations derived from the logprobs.
Args:
logits: Logarithm of probabilities to sample from, probabilities can be
unnormalized.
Returns:
Indices of a permutation sampled without replacement according to the logits.
"""
z = gumbel_noise(logits.shape, device=logits.device, generator=generator)
return torch.argsort(logits + z, dim=-1, descending=True)
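# Illustrative use (assumed toy logits), mirroring sample_msa below: a large
# logit pins row 0 to the front; the remaining rows form a random permutation.
def _demo_gumbel_argsort():
    logits = torch.zeros(5)
    logits[0] = 1e6  # plays the role of cluster_bias_mask * inf in sample_msa
    g = torch.Generator().manual_seed(0)
    order = gumbel_argsort_sample_idx(logits, generator=g)
    assert order[0].item() == 0                       # pinned row comes first
    assert sorted(order.tolist()) == [0, 1, 2, 3, 4]  # a full permutation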
@curry1
def make_masked_msa(batch, config, replace_fraction, seed, eps=1e-6):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
# torch.tensor (not the legacy torch.Tensor constructor) accepts a device kwarg
random_aa = torch.tensor(
[0.05] * 20 + [0., 0.],
dtype=torch.float32,
device=batch['msa'].device
)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * batch['msa_profile'] +
config.same_prob * torch.nn.functional.one_hot(batch['msa'], 22)
)
# Put all remaining probability on [MASK] which is a new column.
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
categorical_probs = torch.nn.functional.pad(
categorical_probs, [0,1], value=mask_prob
)
sh = batch['msa'].shape
mask_position = torch.rand(sh, device=batch['msa'].device) < replace_fraction
mask_position *= batch['msa_mask'].to(mask_position.dtype)
logits = torch.log(categorical_probs + eps)
g = torch.Generator(device=batch["msa"].device)
if seed is not None:
g.manual_seed(seed)
bert_msa = gumbel_max_sample(logits, generator=g)
bert_msa = torch.where(
mask_position,
torch.argmax(bert_msa, dim=-1),
batch['msa']
)
bert_msa *= batch['msa_mask'].to(bert_msa.dtype)
# Mix real and masked MSA.
if 'bert_mask' in batch:
batch['bert_mask'] *= mask_position.to(torch.float32)
else:
batch['bert_mask'] = mask_position.to(torch.float32)
batch['true_msa'] = batch['msa']
batch['msa'] = bert_msa
return batch
@curry1
def nearest_neighbor_clusters(batch, gap_agreement_weight=0.):
"""Assign each extra MSA sequence to its nearest neighbor in sampled MSA."""
device = batch["msa_mask"].device
# Determine how much weight we assign to each agreement. In theory, we could
# use a full blosum matrix here, but right now let's just down-weight gap
# agreement because it could be spurious.
# Never put weight on agreeing on BERT mask.
# torch.tensor (not the legacy torch.Tensor constructor) accepts a device kwarg
weights = torch.tensor(
[1.] * 21 + [gap_agreement_weight] + [0.],
device=device,
)
msa_mask = batch['msa_mask']
msa_one_hot = torch.nn.functional.one_hot(batch['msa'], 23)
extra_mask = batch['extra_msa_mask']
extra_one_hot = torch.nn.functional.one_hot(batch['extra_msa'], 23)
msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot
extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot
agreement = torch.einsum(
'mrc, nrc->nm',
extra_one_hot_masked,
weights * msa_one_hot_masked
)
cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0)
cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask)
cluster_count = torch.sum(cluster_assignment, dim=-1)
cluster_count += 1. # We always include the sequence itself.
msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked)
msa_sum += msa_one_hot_masked
cluster_profile = msa_sum / cluster_count[:, None, None]
extra_deletion_matrix = batch['extra_deletion_matrix']
deletion_matrix = batch['deletion_matrix']
del_sum = torch.einsum(
'nm, mc->nc',
cluster_assignment,
extra_mask * extra_deletion_matrix
)
del_sum += deletion_matrix # Original sequence.
cluster_deletion_mean = del_sum / cluster_count[:, None]
batch['cluster_profile'] = cluster_profile
batch['cluster_deletion_mean'] = cluster_deletion_mean
return batch
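# A toy run (assumed shapes; illustrative only): two sampled rows and two
# extra rows over three residues. Each extra row matches one sampled row
# exactly, so the resulting cluster profiles stay sharp.
def _demo_nearest_neighbor_clusters():
    batch = {
        "msa": torch.tensor([[0, 1, 2], [3, 4, 5]]),
        "msa_mask": torch.ones(2, 3),
        "extra_msa": torch.tensor([[0, 1, 2], [3, 4, 5]]),
        "extra_msa_mask": torch.ones(2, 3),
        "deletion_matrix": torch.zeros(2, 3),
        "extra_deletion_matrix": torch.zeros(2, 3),
    }
    batch = nearest_neighbor_clusters()(batch)
    print(batch["cluster_profile"].shape)        # torch.Size([2, 3, 23])
    print(batch["cluster_deletion_mean"].shape)  # torch.Size([2, 3])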
def create_target_feat(batch):
"""Create the target features"""
batch["target_feat"] = torch.nn.functional.one_hot(
batch["aatype"], 21
).to(torch.float32)
return batch
def create_msa_feat(batch):
"""Create and concatenate MSA features."""
device = batch["msa"]
msa_1hot = torch.nn.functional.one_hot(batch['msa'], 23)
deletion_matrix = batch['deletion_matrix']
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
deletion_mean_value = (
torch.atan(
batch['cluster_deletion_mean'] / 3.) *
(2. / pi)
)[..., None]
msa_feat = torch.cat(
[
msa_1hot,
has_deletion,
deletion_value,
batch['cluster_profile'],
deletion_mean_value
],
dim=-1,
)
batch["msa_feat"] = msa_feat
return batch
def build_extra_msa_feat(batch):
"""Expand extra_msa into 1hot and concat with other extra msa features.
We do this as late as possible as the one_hot extra msa can be very large.
Args:
batch: a dictionary with the following keys:
* 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster
centre. Note - This isn't one-hotted.
* 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given
position.
Returns:
Concatenated tensor of extra MSA features.
"""
# 23 = 20 amino acids + 'X' for unknown + gap + bert mask
extra_msa = batch['extra_msa']
deletion_matrix = batch['extra_deletion_matrix']
msa_1hot = torch.nn.functional.one_hot(extra_msa, 23)
has_deletion = torch.clamp(deletion_matrix, min=0., max=1.)[..., None]
pi = torch.acos(torch.zeros(1, device=deletion_matrix.device)) * 2
deletion_value = (
(torch.atan(deletion_matrix / 3.) * (2. / pi))[..., None]
)
extra_msa_mask = batch['extra_msa_mask']
catted = torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)
return catted
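# Illustrative shape check (assumed toy sizes): 23 one-hot classes plus the
# has_deletion and deletion_value channels give 25 features per position.
def _demo_build_extra_msa_feat():
    batch = {
        "extra_msa": torch.zeros(4, 10, dtype=torch.long),
        "extra_deletion_matrix": torch.zeros(4, 10),
        "extra_msa_mask": torch.ones(4, 10),
    }
    print(build_extra_msa_feat(batch).shape)  # torch.Size([4, 10, 25])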
@curry1
def sample_msa(batch, max_seq, max_extra_msa_seq, seed, inf=1e6):
"""Sample MSA randomly, remaining sequences are stored as `extra_*`.
Args:
batch: batch to sample msa from.
max_seq: number of sequences to sample.
Returns:
Protein with sampled msa.
"""
g = torch.Generator(device=batch["msa"].device)
if seed is not None:
g.manual_seed(seed)
# Sample uniformly among sequences with at least one non-masked position.
logits = (torch.clamp(torch.sum(batch['msa_mask'], dim=-1), 0., 1.) - 1.) * inf
# The cluster_bias_mask can be used to preserve the first row (target
# sequence) for each chain, for example.
if 'cluster_bias_mask' not in batch:
cluster_bias_mask = torch.nn.functional.pad(
batch['msa'].new_zeros(batch['msa'].shape[0] - 1),
(1, 0),
value=1.
)
else:
cluster_bias_mask = batch['cluster_bias_mask']
logits += cluster_bias_mask * inf
index_order = gumbel_argsort_sample_idx(logits, generator=g)
sel_idx = index_order[:max_seq]
extra_idx = index_order[max_seq:][:max_extra_msa_seq]
for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']:
if k in batch:
batch['extra_' + k] = batch[k][extra_idx]
batch[k] = batch[k][sel_idx]
return batch
def make_msa_profile(batch):
"""Compute the MSA profile."""
# Compute the profile for every residue (over all MSA sequences).
batch["msa_profile"] = masked_mean(
batch['msa_mask'][..., None],
torch.nn.functional.one_hot(batch['msa'], 22),
dim=-3,
)
return batch
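# A toy check (assuming masked_mean(mask, value, dim) averages `value` over
# `dim` wherever `mask` is 1): two unmasked rows over three residues.
def _demo_make_msa_profile():
    batch = {
        "msa": torch.tensor([[0, 1, 2], [0, 2, 2]]),
        "msa_mask": torch.ones(2, 3),
    }
    profile = make_msa_profile(batch)["msa_profile"]  # [3, 22]
    print(profile[0])  # both rows are type 0 at residue 0 -> ~1.0 at index 0
    print(profile[1])  # types 1 and 2 -> ~0.5 at indices 1 and 2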
@@ -20,7 +20,7 @@ import ml_collections
import numpy as np
import torch

-from openfold.data import input_pipeline
from openfold.data import input_pipeline, input_pipeline_multimer

FeatureDict = Mapping[str, np.ndarray]
@@ -73,8 +73,10 @@ def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
is_multimer: bool = False
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
@@ -87,11 +89,18 @@ def np_example_to_features(
np_example=np_example, features=feature_names
)

with torch.no_grad():
if(not is_multimer):
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
else:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)

return {k: v for k, v in features.items()}
@@ -107,9 +116,14 @@ class FeaturePipeline:
self,
raw_features: FeatureDict,
mode: str = "train",
is_multimer: bool = False,
) -> FeatureDict:
if(is_multimer and mode != "predict"):
raise ValueError("Multimer mode is not currently trainable")

return np_example_to_features(
np_example=raw_features,
config=self.config,
mode=mode,
is_multimer=is_multimer,
)
@@ -15,7 +15,7 @@
"""Feature processing logic for multimer data pipeline."""

-from typing import Iterable, MutableMapping, List
from typing import Iterable, MutableMapping, List, Mapping

from openfold.data import msa_pairing
from openfold.np import residue_constants

@@ -49,13 +49,11 @@ def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
def pair_and_merge(
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
-is_prokaryote: bool) -> Mapping[str, np.ndarray]:
) -> Mapping[str, np.ndarray]:
"""Runs processing on features to augment, pair and merge.

Args:
all_chain_features: A MutableMap of dictionaries of features for each chain.
-is_prokaryote: Whether the target complex is from a prokaryotic or
-eukaryotic organism.

Returns:
A dictionary of features.
@@ -69,7 +67,8 @@ def pair_and_merge(

if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
-chains=np_chains_list, prokaryotic=is_prokaryote)
chains=np_chains_list
)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
@@ -175,6 +174,7 @@ def process_final(
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)

return np_example
@@ -210,19 +210,23 @@ def _filter_features(
def process_unmerged_features(
-all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]):
all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features.values():
# Convert deletion matrices to float.
chain_features['deletion_matrix'] = np.asarray(
-chain_features.pop('deletion_matrix_int'), dtype=np.float32)
chain_features.pop('deletion_matrix_int'), dtype=np.float32
)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
-chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32)
chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32
)
chain_features['deletion_mean'] = np.mean(
-chain_features['deletion_matrix'], axis=0)
chain_features['deletion_matrix'], axis=0
)

# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
...
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
from openfold.data import (
data_transforms,
data_transforms_multimer,
)
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
]
if(common_cfg.use_templates):
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
])
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
msa_seed = None
if(not common_cfg.resample_msa_in_recycling):
msa_seed = ensemble_seed
transforms.append(
data_transforms_multimer.sample_msa(
max_msa_clusters,
max_extra_msa,
seed=msa_seed,
)
)
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms_multimer.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
seed=(msa_seed + 1) if msa_seed is not None else None,
)
)
transforms.append(data_transforms_multimer.nearest_neighbor_clusters())
transforms.append(data_transforms_multimer.create_msa_feat)
return transforms
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
@data_transforms.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
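# Illustrative only (assumed toy inputs): map_fn stacks per-iteration feature
# dicts along a new trailing dimension, which is how recycling iterations are
# batched above.
def _demo_map_fn():
    feats = map_fn(
        lambda i: {"x": torch.full((2, 2), float(i))},
        torch.arange(3),
    )
    print(feats["x"].shape)  # torch.Size([2, 2, 3])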
@@ -48,7 +48,6 @@ _UNIPROT_PATTERN = re.compile(

@dataclasses.dataclass(frozen=True)
class Identifiers:
-uniprot_accession_id: str = ''
species_id: str = ''
@@ -69,8 +68,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(
-uniprot_accession_id=matches.group('AccessionIdentifier'),
-species_id=matches.group('SpeciesIdentifier'))
species_id=matches.group('SpeciesIdentifier')
)
return Identifiers()
...
@@ -17,7 +17,7 @@
import collections
import functools
import string
-from typing import Any, Dict, Iterable, List, Sequence
from typing import Any, Dict, Iterable, List, Sequence, Mapping

import numpy as np
import pandas as pd

@@ -27,12 +27,6 @@ from openfold.np import residue_constants

# TODO: This stuff should probably also be in a config
-ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)}
-ALPHANUM_ACCESSION_ID_MAP = {
-chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits)
-}  # A-Z,0-9
-NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)}  # 0-9
MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9
@@ -61,14 +55,11 @@ CHAIN_FEATURES = ('num_alignments', 'seq_length')

def create_paired_features(
chains: Iterable[Mapping[str, np.ndarray]],
-prokaryotic: bool,
) -> List[Mapping[str, np.ndarray]]:
"""Returns the original chains with paired NUM_SEQ features.

Args:
chains: A list of feature dictionaries for each chain.
-prokaryotic: Whether the target complex is from a prokaryotic organism.
-Used to determine the distance metric for pairing.

Returns:
A list of feature dictionaries with sequence features including only
return chains return chains
else: else:
updated_chains = [] updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences( paired_chains_to_paired_row_indices = pair_sequences(chains)
chains, prokaryotic)
paired_rows = reorder_paired_rows( paired_rows = reorder_paired_rows(
paired_chains_to_paired_row_indices) paired_chains_to_paired_row_indices)
...@@ -117,8 +107,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: ...@@ -117,8 +107,7 @@ def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
num_res = feature.shape[1] num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype) feature.dtype)
elif feature_name in ('msa_uniprot_accession_identifiers_all_seq', elif feature_name == 'msa_species_identifiers_all_seq':
'msa_species_identifiers_all_seq'):
padding = [b''] padding = [b'']
else: else:
return feature return feature
@@ -136,11 +125,9 @@ def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
-'msa_uniprot_accession_identifiers':
-chain_features['msa_uniprot_accession_identifiers_all_seq'],
'msa_row':
np.arange(len(
-chain_features['msa_uniprot_accession_identifiers_all_seq'])),
chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity': per_seq_similarity,
'gap': per_seq_gap
})

@@ -155,139 +142,6 @@ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
return species_lookup
-@functools.lru_cache(maxsize=65536)
-def encode_accession(accession_id: str) -> int:
-"""Map accession codes to the serial order in which they were assigned."""
-alpha = ALPHA_ACCESSION_ID_MAP  # A-Z
-alphanum = ALPHANUM_ACCESSION_ID_MAP  # A-Z,0-9
-num = NUM_ACCESSION_ID_MAP  # 0-9
-coding = 0
-# This is based on the uniprot accession id format
-# https://www.uniprot.org/help/accession_numbers
-if accession_id[0] in {'O', 'P', 'Q'}:
-bases = (alpha, num, alphanum, alphanum, alphanum, num)
-elif len(accession_id) == 6:
-bases = (alpha, num, alpha, alphanum, alphanum, num)
-elif len(accession_id) == 10:
-bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum,
-alphanum, num)
-product = 1
-for place, base in zip(reversed(accession_id), reversed(bases)):
-coding += base[place] * product
-product *= len(base)
-return coding
-def _calc_id_diff(id_a: bytes, id_b: bytes) -> int:
-return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode()))
-def _find_all_accession_matches(accession_id_lists: List[List[bytes]],
-diff_cutoff: int = 20
-) -> List[List[Any]]:
-"""Finds accession id matches across the chains based on their difference."""
-all_accession_tuples = []
-current_tuple = []
-tokens_used_in_answer = set()
-def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool:
-return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple))
-def _all_tokens_not_used_before() -> bool:
-return all((s not in tokens_used_in_answer for s in current_tuple))
-def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None:
-if level == len(accession_id_lists) - 1:
-if _all_tokens_not_used_before():
-all_accession_tuples.append(list(current_tuple))
-for s in current_tuple:
-tokens_used_in_answer.add(s)
-return
-if level == -1:
-new_list = accession_id_lists[level+1]
-else:
-new_list = [(_calc_id_diff(accession_id, s), s) for
-s in accession_id_lists[level+1]]
-new_list = sorted(new_list)
-new_list = [s for d, s in new_list]
-for s in new_list:
-if (_matches_all_in_current_tuple(s, diff_cutoff) and
-s not in tokens_used_in_answer):
-current_tuple.append(s)
-dfs(level + 1, s)
-current_tuple.pop()
-dfs(-1, '')
-return all_accession_tuples
-def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series:
-matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id]
-return matched_df.iloc[0]
-def _match_rows_by_genetic_distance(
-this_species_msa_dfs: List[pd.DataFrame],
-cutoff: int = 20) -> List[List[int]]:
-"""Finds MSA sequence pairings across chains within a genetic distance cutoff.
-The genetic distance between two sequences is approximated by taking the
-difference in their UniProt accession ids.
-Args:
-this_species_msa_dfs: a list of dataframes containing MSA features for
-sequences for a specific species. If species is missing for a chain, the
-dataframe is set to None.
-cutoff: the genetic distance cutoff.
-Returns:
-A list of lists, each containing M indices corresponding to paired MSA rows,
-where M is the number of chains.
-"""
-num_examples = len(this_species_msa_dfs)  # N
-accession_id_lists = []  # M
-match_index_to_chain_index = {}
-for chain_index, species_df in enumerate(this_species_msa_dfs):
-if species_df is not None:
-accession_id_lists.append(
-list(species_df.msa_uniprot_accession_identifiers.values))
-# Keep track of which of the this_species_msa_dfs are not None.
-match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index
-all_accession_id_matches = _find_all_accession_matches(
-accession_id_lists, cutoff)  # [k, M]
-all_paired_msa_rows = []  # [k, N]
-for accession_id_match in all_accession_id_matches:
-paired_msa_rows = []
-for match_index, accession_id in enumerate(accession_id_match):
-# Map back to chain index.
-chain_index = match_index_to_chain_index[match_index]
-seq_series = _accession_row(
-this_species_msa_dfs[chain_index], accession_id)
-if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or
-seq_series.gap > SEQUENCE_GAP_CUTOFF):
-continue
-else:
-paired_msa_rows.append(seq_series.msa_row)
-# If a sequence is skipped based on sequence similarity to the respective
-# target sequence or a gap cutoff, the lengths of accession_id_match and
-# paired_msa_rows will be different. Skip this match.
-if len(paired_msa_rows) == len(accession_id_match):
-paired_and_non_paired_msa_rows = np.array([-1] * num_examples)
-matched_chain_indices = list(match_index_to_chain_index.values())
-paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows
-all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows))
-return all_paired_msa_rows
def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
) -> List[List[int]]:
"""Finds MSA sequence pairings across chains based on sequence similarity.
@@ -324,8 +178,9 @@ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
return all_paired_msa_rows
-def pair_sequences(examples: List[Mapping[str, np.ndarray]],
-prokaryotic: bool) -> Dict[int, np.ndarray]:
def pair_sequences(
examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""

num_examples = len(examples)
isinstance(species_df, pd.DataFrame)]) > 600): isinstance(species_df, pd.DataFrame)]) > 600):
continue continue
# In prokaryotes (and some eukaryotes), interacting genes are often
# co-located on the chromosome into operons. Because of that we can assume
# that if two proteins' intergenic distance is less than a threshold, they
# two proteins will form an an interacting pair.
# In most eukaryotes, a single protein's MSA can contain many paralogs.
# Two genes may interact even if they are not close by genomic distance.
# In case of eukaryotes, some methods pair MSA sequences using sequence
# similarity method.
# See Jinbo Xu's work:
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28.
if prokaryotic:
paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs)
if not paired_msa_rows:
continue
else:
paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows) all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
@@ -431,13 +270,19 @@ def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
def _correct_post_merged_feats(
np_example: Mapping[str, np.ndarray],
np_chains_list: Sequence[Mapping[str, np.ndarray]],
-pair_msa_sequences: bool) -> Mapping[str, np.ndarray]:
pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging."""

-np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],
-dtype=np.int32)
-np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],
-dtype=np.int32)
num_res = np_example['aatype'].shape[0]
np_example['seq_length'] = np.asarray(
[num_res] * num_res,
dtype=np.int32
)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0],
dtype=np.int32
)
if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
@@ -449,29 +294,41 @@ def _correct_post_merged_feats(
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)

np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)
# Initialize Bert mask with masked out off diagonals.
-msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)
-for x in np_chains_list]
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32)
for x in np_chains_list
]

np_example['bert_mask'] = block_diag(
-*msa_masks, pad_value=0)
*msa_masks, pad_value=0
)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1

# Initialize Bert mask with masked out off diagonals.
-msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for
-x in np_chains_list]
-msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
-x in np_chains_list]
msa_masks = [
np.ones(x['msa'].shape, dtype=np.float32) for
x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.float32) for
x in np_chains_list
]

msa_mask_block_diag = block_diag(
-*msa_masks, pad_value=0)
*msa_masks, pad_value=0
)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
-[msa_mask_all_seq, msa_mask_block_diag], axis=0)
[msa_mask_all_seq, msa_mask_block_diag],
axis=0
)

return np_example
...
@@ -16,6 +16,7 @@
"""Functions for parsing various file formats."""

import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
@@ -29,8 +30,7 @@ class Msa:
"""Class representing a parsed MSA file"""
sequences: Sequence[str]
deletion_matrix: DeletionMatrix
-descriptions: Sequence[str]
descriptions: Optional[Sequence[str]]

def __post_init__(self):
if(not (
@@ -642,3 +642,20 @@ def parse_hmmsearch_a3m(
hits.append(hit)

return hits
def parse_hmmsearch_sto(
output_string: str,
input_sequence: str
) -> Sequence[TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = convert_stockholm_to_a3m(
output_string,
remove_first_row_gaps=False
)
template_hits = parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False
)
return template_hits
@@ -220,13 +220,6 @@ def _assess_hhsearch_hit(
template_sequence = hit.hit_sequence.replace("-", "")
length_ratio = float(len(template_sequence)) / len(query_sequence)
-# Check whether the template is a large subsequence or duplicate of original
-# query. This can happen due to duplicate entries in the PDB database.
-duplicate = (
-template_sequence in query_sequence
-and length_ratio > max_subsequence_ratio
-)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
date = release_dates[hit_pdb_code.upper()]
raise DateError(
@@ -240,6 +233,13 @@ def _assess_hhsearch_hit(
f"Align ratio: {align_ratio}."
)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
if duplicate:
raise DuplicateError(
"Template is an exact subsequence of query with large "
@@ -770,7 +770,7 @@ def _prefilter_hit(
except PrefilterError as e:
hit_name = f"{hit_pdb_code}_{hit_chain_id}"
msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
-logging.info("%s: %s", query_pdb_code, msg)
logging.info(msg)
if strict_error_check and isinstance(e, (DateError, DuplicateError)):
# In strict mode we treat some prefilter cases as errors.
return PrefilterResult(valid=False, error=msg, warning=None)
@@ -826,6 +826,7 @@ def _process_single_hit(
query_sequence,
template_sequence,
)

# Fail if we can't find the mmCIF file.
cif_string = _read_file(cif_path)
raise ValueError( raise ValueError(
"max_template_date must be set and have format YYYY-MM-DD." "max_template_date must be set and have format YYYY-MM-DD."
) )
self.max_hits = max_hits self._max_hits = max_hits
self._kalign_binary_path = kalign_binary_path self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check self._strict_error_check = strict_error_check
...@@ -997,33 +998,23 @@ class TemplateHitFeaturizer(abc.ABC): ...@@ -997,33 +998,23 @@ class TemplateHitFeaturizer(abc.ABC):
query_sequence: str, query_sequence: str,
hits: Sequence[parsers.TemplateHit] hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult: ) -> TemplateSearchResult:
""" Computes the templates for a given query sequence """
class HhsearchHitFeaturizer(TemplateHitFeaturizer): class HhsearchHitFeaturizer(TemplateHitFeaturizer):
def get_templates(
self,
query_sequence: str,
-query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
-logging.info("Searching for template for: %s", query_pdb_code)
logging.info("Searching for template for: %s", query_sequence)

template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []

-# Always use a max_template_date. Set to query_release_date minus 60 days
-# if that's earlier.
-template_cutoff_date = self._max_template_date
-if query_release_date:
-delta = datetime.timedelta(days=60)
-if query_release_date - delta < template_cutoff_date:
-template_cutoff_date = query_release_date - delta
-assert template_cutoff_date < query_release_date
-assert template_cutoff_date <= self._max_template_date
-num_hits = 0
already_seen = set()

errors = []
warnings = []
@@ -1032,7 +1023,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
hit=hit,
-max_template_date=template_cutoff_date,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
@@ -1057,17 +1048,16 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
for i in idx:
# We got all the templates we wanted, stop processing hits.
-if num_hits >= self.max_hits:
if len(already_seen) >= self._max_hits:
break

hit = filtered[i]

result = _process_single_hit(
query_sequence=query_sequence,
-query_pdb_code=query_pdb_code,
hit=hit,
mmcif_dir=self._mmcif_dir,
-max_template_date=template_cutoff_date,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
@@ -1091,8 +1081,10 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
result.warning,
)
else:
-# Increment the hit counter, since we got features out of this hit.
-num_hits += 1
already_seen_key = result.features["template_sequence"]
if(already_seen_key in already_seen):
continue
already_seen.add(already_seen_key)
for k in template_features:
template_features[k].append(result.features[k])
@@ -1118,6 +1110,8 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
query_sequence: str,
hits: Sequence[parsers.TemplateHit]
) -> TemplateSearchResult:
logging.info("Searching for template for: %s", query_sequence)

template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
template_features[template_feature_name] = []

@@ -1126,15 +1120,43 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
errors = []
warnings = []
-if not hits or hits[0].sum_probs is None:
-sorted_hits = hits
-else:
-sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
# DISCREPANCY: This filtering scheme saves time
filtered = []
for hit in hits:
prefilter_result = _prefilter_hit(
query_sequence=query_sequence,
hit=hit,
max_template_date=self._max_template_date,
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
)
if prefilter_result.error:
errors.append(prefilter_result.error)
if prefilter_result.warning:
warnings.append(prefilter_result.warning)
if prefilter_result.valid:
filtered.append(hit)

-for hit in sorted_hits:
filtered = list(
sorted(
filtered, key=lambda x: x.sum_probs if x.sum_probs else 0., reverse=True
)
)

idx = list(range(len(filtered)))
if(self._shuffle_top_k_prefiltered):
stk = self._shuffle_top_k_prefiltered
idx[:stk] = np.random.permutation(idx[:stk])

for i in idx:
if(len(already_seen) >= self._max_hits):
break

hit = filtered[i]

result = _process_single_hit(
query_sequence=query_sequence,
hit=hit,
...
@@ -18,7 +18,7 @@ import glob
import logging
import os
import subprocess
-from typing import Sequence
from typing import Sequence, Optional

from openfold.data import parsers
from openfold.data.tools import utils

@@ -71,11 +71,12 @@ class HHSearch:
def input_format(self) -> str:
return 'a3m'

-def query(self, a3m: str) -> str:
def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
-hhr_path = os.path.join(query_tmp_dir, "output.hhr")
output_dir = query_tmp_dir if output_dir is None else output_dir
hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
hhr = f.read() hhr = f.read()
return hhr return hhr
def get_template_hits(self, @staticmethod
def get_template_hits(
output_string: str, output_string: str,
input_sequence: str input_sequence: str
) -> Sequence[parsers.TemplateHit]: ) -> Sequence[parsers.TemplateHit]:
......
@@ -32,7 +32,8 @@ class Hmmsearch(object):
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
-flags: Optional[Sequence[str]] = None):
flags: Optional[Sequence[str]] = None
):
"""Initializes the Python hmmsearch wrapper.

Args:
@@ -71,17 +72,23 @@ class Hmmsearch(object):
def input_format(self) -> str:
return 'sto'

-def query(self, msa_sto: str) -> str:
def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
-hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
-model_construction='hand')
-return self.query_with_hmm(hmm)
hmm = self.hmmbuild_runner.build_profile_from_sto(
msa_sto,
model_construction='hand'
)
return self.query_with_hmm(hmm, output_dir)

-def query_with_hmm(self, hmm: str) -> str:
def query_with_hmm(self,
hmm: str,
output_dir: Optional[str] = None
) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
-out_path = os.path.join(query_tmp_dir, 'output.sto')
output_dir = query_tmp_dir if output_dir is None else output_dir
out_path = os.path.join(output_dir, 'hmm_output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)
return out_msa return out_msa
def get_template_hits(self, @staticmethod
def get_template_hits(
output_string: str, output_string: str,
input_sequence: str input_sequence: str
) -> Sequence[parsers.TemplateHit]: ) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool.""" """Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m( template_hits = parsers.parse_hmmsearch_sto(
output_string, output_string,
remove_first_row_gaps=False input_sequence,
)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False
) )
return template_hits return template_hits
@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import torch
import torch.nn as nn
from typing import Tuple

from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
pseudo_beta_fn,
dgram_from_positions,
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from openfold.utils import geometry
-from openfold.utils.tensor_utils import one_hot
from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
from openfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
@@ -85,20 +99,16 @@ class InputEmbedder(nn.Module):
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
-def forward(
-self,
-tf: torch.Tensor,
-ri: torch.Tensor,
-msa: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
-tf:
-"target_feat" features of shape [*, N_res, tf_dim]
-ri:
-"residue_index" features of shape [*, N_res]
-msa:
-"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
batch: Dict containing
"target_feat":
Features of shape [*, N_res, tf_dim]
"residue_index":
Features of shape [*, N_res]
"msa_feat":
Features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
@@ -106,6 +116,10 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
"""
tf = batch["target_feat"]
ri = batch["residue_index"]
msa = batch["msa_feat"]

# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)

@@ -126,6 +140,154 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if(self.use_chain_relative):
self.no_bins = (
2 * max_relative_idx + 2 +
1 +
2 * max_relative_chain + 2
)
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if(self.use_chain_relative):
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
rel_pos = torch.nn.functional.one_hot(
final_offset,
2 * self.max_relative_idx + 2,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) *
torch.ones_like(clipped_rel_chain)
)
rel_chain = torch.nn.functional.one_hot(
final_rel_chain,
2 * max_rel_chain + 2,
)
rel_feats.append(rel_chain)
else:
rel_pos = torch.nn.functional.one_hot(
clipped_offset, 2 * self.max_relative_idx + 1,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(
self.linear_relpos.weight.dtype
)
return self.linear_relpos(rel_feat)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
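# A standalone sanity check of the relpos arithmetic above (toy tensors and
# illustrative sizes, not from a real batch; assumes only standard PyTorch).
# With use_chain_relative=True, the concatenated features comprise
# 2 * max_relative_idx + 2 position bins (clipped offsets plus one
# "different chain" bin), 1 same-entity flag, and 2 * max_relative_chain + 2
# chain bins: 66 + 1 + 6 = 73 for the config values above.
def _relpos_bins_demo():
    max_relative_idx, max_relative_chain = 32, 2
    no_bins = (2 * max_relative_idx + 2) + 1 + (2 * max_relative_chain + 2)
    assert no_bins == 73

    # Two chains of 3 and 2 residues (hypothetical toy batch)
    pos = torch.tensor([0, 1, 2, 0, 1])
    asym_id = torch.tensor([0, 0, 0, 1, 1])

    offset = pos[..., None] - pos[..., None, :]
    clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)

    # Cross-chain pairs all land in the extra "different chain" bin
    final = torch.where(
        asym_id[..., None] == asym_id[..., None, :],
        clipped,
        torch.full_like(clipped, 2 * max_relative_idx + 1),
    )
    rel_pos = torch.nn.functional.one_hot(final, 2 * max_relative_idx + 2)
    assert rel_pos.shape == (5, 5, 66)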
class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.
@@ -312,6 +474,102 @@ class TemplatePairEmbedder(nn.Module):
        return x
class ExtraMSAEmbedder(nn.Module):
    """
    Embeds unclustered MSA sequences.
@@ -350,3 +608,315 @@ class ExtraMSAEmbedder(nn.Module):
        x = self.linear(x)
        return x
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True
):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.embed_angles:
ret["template_single_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super().__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
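# Minimal sketch of the additive embedding pattern TemplatePairEmbedderMultimer
# uses above: each input feature gets its own Linear projection to c_out, and
# the projections are summed into a single activation. Dimensions are
# illustrative only.
def _additive_pair_embedding_demo():
    c_out, n_res = 64, 8
    dgram = torch.randn(n_res, n_res, 39)   # stand-in for template_dgram
    mask_2d = torch.rand(n_res, n_res)      # stand-in for pseudo_beta_mask_2d

    dgram_linear = Linear(39, c_out)
    mask_linear = Linear(1, c_out)

    act = dgram_linear(dgram) + mask_linear(mask_2d[..., None])
    assert act.shape == (n_res, n_res, c_out)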
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super().__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_tensor(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
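# Note the contrast with the monomer TemplateEmbedder: instead of
# TemplatePointwiseAttention, the multimer path above averages the pair-stack
# output over the template dimension and applies ReLU plus a linear
# projection. A standalone sketch of that reduction (illustrative dims only):
def _template_average_demo():
    n_templ, n_res, c_t, c_z = 4, 8, 16, 32
    t = torch.randn(n_templ, n_res, n_res, c_t)  # stand-in for stack output
    linear_t = Linear(c_t, c_z)

    t = torch.sum(t, dim=-4) / n_templ           # [n_res, n_res, c_t]
    t = torch.nn.functional.relu(t)
    t = linear_t(t)                              # [n_res, n_res, c_z]
    assert t.shape == (n_res, n_res, c_z)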
@@ -17,28 +17,25 @@ from functools import partial
import torch
import torch.nn as nn

from openfold.data import data_transforms_multimer
from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    dgram_from_positions,
    atom14_to_atom37,
)
from openfold.model.embedders import (
    InputEmbedder,
    InputEmbedderMultimer,
    RecyclingEmbedder,
    TemplateEmbedder,
    TemplateEmbedderMultimer,
    ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.utils.loss import (
    compute_plddt,
)
@@ -69,24 +66,28 @@ class AlphaFold(nn.Module):
        extra_msa_config = config.extra_msa

        # Main trunk + structure module
if(self.globals.is_multimer):
self.input_embedder = InputEmbedderMultimer(
**config["input_embedder"],
)
else:
            self.input_embedder = InputEmbedder(
                **config["input_embedder"],
            )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        if(self.globals.is_multimer):
            self.template_embedder = TemplateEmbedderMultimer(
                template_config,
            )
        else:
            self.template_embedder = TemplateEmbedder(
                template_config,
            )

        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
@@ -96,7 +97,9 @@ class AlphaFold(nn.Module):
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )

        self.structure_module = StructureModule(
            is_multimer=self.globals.is_multimer,
            **config["structure_module"],
        )
@@ -106,71 +109,6 @@ class AlphaFold(nn.Module):
        self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
        # Primary output dictionary
        outputs = {}
@@ -197,11 +135,7 @@ class AlphaFold(nn.Module):
        # m: [*, S_c, N, C_m]
        # z: [*, N, N, C_z]
        m, z = self.input_embedder(feats)
        # Initialize the recycling embeddings, if needs be
        if None in [m_1_prev, z_prev, x_prev]:
@@ -257,40 +191,74 @@ class AlphaFold(nn.Module):
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }

            if(self.globals.is_multimer):
                asym_id = feats["asym_id"]
                multichain_mask_2d = (
                    asym_id[..., None] == asym_id[..., None, :]
                )
                template_embeds = self.template_embedder(
                    template_feats,
                    z,
                    pair_mask.to(dtype=z.dtype),
                    no_batch_dims,
                    chunk_size=self.globals.chunk_size,
                    multichain_mask_2d=multichain_mask_2d,
                )
                feats["template_torsion_angles_mask"] = (
                    template_embeds["template_mask"]
                )
            else:
                template_embeds = self.template_embedder(
                    template_feats,
                    z,
                    pair_mask.to(dtype=z.dtype),
                    no_batch_dims,
                    self.globals.chunk_size
                )
            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if(
                self.config.template.embed_angles or
                (self.globals.is_multimer and self.config.template.enabled)
            ):
                # [*, S = S_c + S_t, N, C_m]
                m = torch.cat(
                    [m, template_embeds["template_single_embedding"]],
                    dim=-3
                )

                # [*, S, N]
                if(not self.globals.is_multimer):
                    torsion_angles_mask = feats["template_torsion_angles_mask"]
                    msa_mask = torch.cat(
                        [feats["msa_mask"], torsion_angles_mask[..., 2]],
                        dim=-2
                    )
                else:
                    msa_mask = torch.cat(
                        [feats["msa_mask"], template_embeds["template_mask"]],
                        dim=-2,
                    )
        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
            if(self.globals.is_multimer):
                extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
            else:
                extra_msa_fn = build_extra_msa_feat

            # [*, S_e, N, C_e]
            extra_msa_feat = extra_msa_fn(feats)
            extra_msa_feat = self.extra_msa_embedder(extra_msa_feat)

            # [*, N, N, C_z]
            z = self.extra_msa_stack(
                extra_msa_feat,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=extra_msa_feat.dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
@@ -340,14 +308,14 @@ class AlphaFold(nn.Module):
        return outputs, m_1_prev, z_prev, x_prev

    def _disable_activation_checkpointing(self):
        self.template_embedder.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None

        for b in self.extra_msa_stack.blocks:
            b.ckpt = False

    def _enable_activation_checkpointing(self):
        self.template_embedder.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
...
@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from openfold.utils.geometry.quat_rigid import QuatRigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
@@ -155,14 +158,14 @@ class PointProjection(nn.Module):
    def __init__(self,
        c_hidden: int,
        num_points: int,
        no_heads: int,
        return_local_points: bool = False,
    ):
        super().__init__()

        self.return_local_points = return_local_points
        self.no_heads = no_heads

        self.linear = Linear(c_hidden, no_heads * 3 * num_points)

    def forward(self,
        activations: torch.Tensor,
@@ -171,11 +174,13 @@ class PointProjection(nn.Module):
        # TODO: Needs to run in high precision during training
        points_local = self.linear(activations)
        points_local = points_local.reshape(
            *points_local.shape[:-1],
            self.no_heads,
            -1,
        )

        points_local = torch.split(
            points_local, points_local.shape[-1] // 3, dim=-1
        )
        points_local = Vec3Array(*points_local)

        points_global = rigids[..., None, None].apply_to_point(points_local)
@@ -184,7 +189,7 @@ class PointProjection(nn.Module):
        return points_global
# WEIGHTS CHANGED
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.
@@ -199,6 +204,7 @@ class InvariantPointAttention(nn.Module):
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
        is_multimer: bool = False,
    ):
        """
        Args:
@@ -225,14 +231,14 @@ class InvariantPointAttention(nn.Module):
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps
        self.is_multimer = is_multimer

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer))

        self.linear_q_points = PointProjection(
            self.c_s,
            self.no_qk_points,
            self.no_heads
        )
@@ -240,15 +246,25 @@ class InvariantPointAttention(nn.Module):
        if(is_multimer):
            self.linear_k = Linear(self.c_s, hc, bias=False)
            self.linear_v = Linear(self.c_s, hc, bias=False)

            self.linear_k_points = PointProjection(
                self.c_s,
                self.no_qk_points,
                self.no_heads,
            )

            self.linear_v_points = PointProjection(
                self.c_s,
                self.no_v_points,
                self.no_heads,
            )
        else:
            self.linear_kv = Linear(self.c_s, 2 * hc)

            self.linear_kv_points = PointProjection(
                self.c_s,
                self.no_qk_points + self.no_v_points,
                self.no_heads,
            )
@@ -290,25 +306,48 @@ class InvariantPointAttention(nn.Module):
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, P_qk]
        q_pts = self.linear_q_points(s, r)

        # The following two blocks are equivalent
        # They're separated only to preserve compatibility with old AF weights
        if(self.is_multimer):
            # [*, N_res, H * C_hidden]
            k = self.linear_k(s)
            v = self.linear_v(s)

            # [*, N_res, H, C_hidden]
            k = k.view(k.shape[:-1] + (self.no_heads, -1))
            v = v.view(v.shape[:-1] + (self.no_heads, -1))

            # [*, N_res, H, P_qk, 3]
            k_pts = self.linear_k_points(s, r)

            # [*, N_res, H, P_v, 3]
            v_pts = self.linear_v_points(s, r)
        else:
            # [*, N_res, H * 2 * C_hidden]
            kv = self.linear_kv(s)

            # [*, N_res, H, 2 * C_hidden]
            kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

            # [*, N_res, H, C_hidden]
            k, v = torch.split(kv, self.c_hidden, dim=-1)

            kv_pts = self.linear_kv_points(s, r)

            # [*, N_res, H, (P_q + P_v), 3]
            kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

            # [*, N_res, H, P_q/P_v, 3]
            k_pts, v_pts = torch.split(
                kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
            )
        ##########################
        # Compute attention scores
@@ -324,12 +363,14 @@ class InvariantPointAttention(nn.Module):
        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]

        # [*, N_res, N_res, H, P_q]
        pt_att = sum([c**2 for c in pt_att])

        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
@@ -364,9 +405,7 @@ class InvariantPointAttention(nn.Module):
        # As DeepMind explains, this manual matmul ensures that the operation
        # happens in float32.
        # [*, N_res, H, P_v]
        o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1)
        o_pt = o_pt.sum(dim=-3)

        # [*, N_res, H, P_v]
@@ -493,6 +532,7 @@ class StructureModule(nn.Module):
        trans_scale_factor,
        epsilon,
        inf,
        is_multimer=False,
        **kwargs,
    ):
        """
@@ -546,6 +586,7 @@ class StructureModule(nn.Module):
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf
        self.is_multimer = is_multimer

        # To be lazily initialized later
        self.default_frames = None
@@ -567,6 +608,7 @@ class StructureModule(nn.Module):
            self.no_v_points,
            inf=self.inf,
            eps=self.epsilon,
            is_multimer=self.is_multimer,
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
@@ -588,26 +630,61 @@ class StructureModule(nn.Module):
            self.epsilon,
        )
    def _init_residue_constants(self, float_dtype, device):
        if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
    def _forward_monomer(self,
        s,
        z,
        aatype,
        mask=None,
    ):
        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])
@@ -690,51 +767,97 @@ class StructureModule(nn.Module):
        return outputs
    def _forward_multimer(self,
        s,
        z,
        aatype,
        mask=None,
    ):
        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(z)

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid3Array.identity(
            s.shape[:-1],
            s.device,
        )
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(s, z, rigids, mask)
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids @ self.bb_update(s)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                rigids.scale_translation(self.trans_scale_factor),
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            preds = {
                "frames": rigids.scale_translation(
                    self.trans_scale_factor
                ).to_tensor7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
            }

            outputs.append(preds)

            if i < (self.no_blocks - 1):
                rigids = rigids.stop_rot_gradient()

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def forward(
        self,
        s,
        z,
        aatype,
        mask=None,
    ):
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
        Returns:
            A dictionary of outputs
        """
        if(self.is_multimer):
            outputs = self._forward_multimer(s, z, aatype, mask)
        else:
            outputs = self._forward_monomer(s, z, aatype, mask)

        return outputs
@@ -62,7 +62,7 @@ class Protein:
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS):
            raise ValueError(
                f"Cannot build an instance with more than {PDB_MAX_CHAINS} "
                "chains because these cannot be written to PDB format"
            )
...
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for all atom representations."""
from functools import partial
from typing import Dict, Text, Tuple
import torch
from openfold.np import residue_constants as rc
from openfold.utils import geometry, tensor_utils
import numpy as np
def squared_difference(x, y):
    # torch, not jnp: this module is the PyTorch port of DeepMind's JAX code
    return torch.square(x - y)
def get_rc_tensor(rc_np, aatype):
return torch.tensor(rc_np, device=aatype.device)[aatype]
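# Standalone illustration of the get_rc_tensor lookup pattern: index a
# per-restype constant table with an aatype tensor, broadcasting over batch
# dims. The table below is a toy stand-in for arrays like
# rc.RESTYPE_ATOM37_MASK.
def _get_rc_tensor_demo():
    rc_np = np.arange(21 * 3, dtype=np.float32).reshape(21, 3)
    aatype = torch.tensor([[0, 4, 20]])
    per_residue = get_rc_tensor(rc_np, aatype)
    assert per_residue.shape == (1, 3, 3)  # [batch, num_res, table_width]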
def atom14_to_atom37(
atom14_data: torch.Tensor, # (*, N, 14, ...)
aatype: torch.Tensor # (*, N)
) -> torch.Tensor: # (*, N, 37, ...)
"""Convert atom14 to atom37 representation."""
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
no_batch_dims = len(aatype.shape) - 1
atom37_data = tensor_utils.batched_gather(
atom14_data,
idx_atom37_to_atom14,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1
)
atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
if len(atom14_data.shape) == no_batch_dims + 2:
atom37_data *= atom37_mask
elif len(atom14_data.shape) == no_batch_dims + 3:
atom37_data *= atom37_mask[..., None].to(atom37_data.dtype)
else:
raise ValueError("Incorrectly shaped data")
return atom37_data
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
"""Convert Atom37 positions to Atom14 positions."""
residx_atom14_to_atom37 = get_rc_tensor(
rc.RESTYPE_ATOM14_TO_ATOM37, aatype
)
no_batch_dims = len(aatype.shape)
atom14_mask = tensor_utils.batched_gather(
all_atom_mask,
residx_atom14_to_atom37,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
).to(torch.float32)
# create a mask for known groundtruth positions
atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype)
# gather the groundtruth positions
atom14_positions = tensor_utils.batched_gather(
all_atom_pos,
residx_atom14_to_atom37,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
)
atom14_positions = atom14_mask * atom14_positions
return atom14_positions, atom14_mask
def get_alt_atom14(aatype, positions: torch.Tensor, mask):
"""Get alternative atom14 positions."""
# pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14)
renaming_transform = get_rc_tensor(rc.RENAMING_MATRICES, aatype)
alternative_positions = torch.sum(
positions[..., None, :] * renaming_transform[..., None],
dim=-2
)
# Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position)
alternative_mask = torch.sum(mask[..., None] * renaming_transform, dim=-2)
return alternative_positions, alternative_mask
def atom37_to_frames(
aatype: torch.Tensor, # (...)
all_atom_positions: torch.Tensor, # (..., 37)
all_atom_mask: torch.Tensor, # (..., 37)
) -> Dict[Text, torch.Tensor]:
"""Computes the frames for the up to 8 rigid groups for each residue."""
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
no_batch_dims = len(aatype.shape) - 1
# Compute the gather indices for all residues in the chain.
# shape (N, 8, 3)
residx_rigidgroup_base_atom37_idx = get_rc_tensor(
rc.RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype
)
# Gather the base atom positions for each rigid group.
base_atom_pos = tensor_utils.batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=no_batch_dims + 1,
no_batch_dims=no_batch_dims + 1,
)
# Compute the Rigids.
point_on_neg_x_axis = base_atom_pos[..., :, :, 0]
origin = base_atom_pos[..., :, :, 1]
point_on_xy_plane = base_atom_pos[..., :, :, 2]
gt_rotation = geometry.Rot3Array.from_two_vectors(
origin - point_on_neg_x_axis, point_on_xy_plane - origin
)
gt_frames = geometry.Rigid3Array(gt_rotation, origin)
# Compute a mask whether the group exists.
# (N, 8)
group_exists = get_rc_tensor(rc.RESTYPE_RIGIDGROUP_MASK, aatype)
# Compute a mask whether ground truth exists for the group
gt_atoms_exist = tensor_utils.batched_gather(  # shape (N, 8, 3)
    all_atom_mask.to(dtype=torch.float32),
    residx_rigidgroup_base_atom37_idx,
    dim=no_batch_dims + 1,
    no_batch_dims=no_batch_dims + 1,
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists  # (N, 8)
# Adapt backbone frame to old convention (mirror x-axis and z-axis).
rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1])
rots[0, 0, 0] = -1
rots[0, 2, 2] = -1
gt_frames = gt_frames.compose_rotation(
geometry.Rot3Array.from_array(
torch.tensor(rots, device=aatype.device)
)
)
# The frames for ambiguous rigid groups are just rotated by 180 degrees
# around the x-axis. The ambiguous group is always the last chi-group.
restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32)
restype_rigidgroup_rots = np.tile(
np.eye(3, dtype=np.float32), [21, 8, 1, 1]
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1
restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1
# Gather the ambiguity information for each residue.
residx_rigidgroup_is_ambiguous = torch.tensor(
restype_rigidgroup_is_ambiguous,
device=aatype.device,
)[aatype]
ambiguity_rot = torch.tensor(
restype_rigidgroup_rots,
device=aatype.device,
)[aatype]
ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot)
# Create the alternative ground truth frames.
alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot)
fix_shape = lambda x: x.reshape(x.shape[:-2] + (8,))
# reshape back to original residue layout
gt_frames = fix_shape(gt_frames)
gt_exists = fix_shape(gt_exists)
group_exists = fix_shape(group_exists)
residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous)
alt_gt_frames = fix_shape(alt_gt_frames)
return {
'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8)
'rigidgroups_gt_exists': gt_exists, # (..., 8)
'rigidgroups_group_exists': group_exists, # (..., 8)
'rigidgroups_group_is_ambiguous':
residx_rigidgroup_is_ambiguous, # (..., 8)
'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8)
}
def torsion_angles_to_frames(
aatype: torch.Tensor, # (N)
backb_to_global: geometry.Rigid3Array, # (N)
torsion_angles_sin_cos: torch.Tensor # (N, 7, 2)
) -> geometry.Rigid3Array: # (N, 8)
"""Compute rigid group frames from torsion angles."""
# Gather the default frames for all rigid groups.
# geometry.Rigid3Array with shape (N, 8)
m = get_rc_tensor(rc.restype_rigid_group_default_frame, aatype)
default_frames = geometry.Rigid3Array.from_array4x4(m)
# Create the rotation matrices according to the given angles (each frame is
# defined such that its rotation is around the x-axis).
sin_angles = torsion_angles_sin_cos[..., 0]
cos_angles = torsion_angles_sin_cos[..., 1]
# insert zero rotation for backbone group.
num_residues = aatype.shape[-1]
sin_angles = torch.cat(
[
torch.zeros_like(aatype, dtype=sin_angles.dtype).unsqueeze(-1),
sin_angles,
],
dim=-1)
cos_angles = torch.cat(
[
torch.ones_like(aatype, dtype=cos_angles.dtype).unsqueeze(-1),
cos_angles
],
dim=-1
)
zeros = torch.zeros_like(sin_angles)
ones = torch.ones_like(sin_angles)
# all_rots are geometry.Rot3Array with shape (..., N, 8)
all_rots = geometry.Rot3Array(
ones, zeros, zeros,
zeros, cos_angles, -sin_angles,
zeros, sin_angles, cos_angles
)
# Apply rotations to the frames.
all_frames = default_frames.compose_rotation(all_rots)
# chi2, chi3, and chi4 frames do not transform to the backbone frame but to
# the previous frame. So chain them up accordingly.
chi1_frame_to_backb = all_frames[..., 4]
chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[..., 5]
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[..., 6]
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[..., 7]
all_frames_to_backb = geometry.Rigid3Array.cat(
[
all_frames[..., 0:5],
chi2_frame_to_backb[..., None],
chi3_frame_to_backb[..., None],
chi4_frame_to_backb[..., None]
],
dim=-1
)
# Create the global frames.
# shape (N, 8)
all_frames_to_global = backb_to_global[..., None] @ all_frames_to_backb
return all_frames_to_global
def frames_and_literature_positions_to_atom14_pos(
aatype: torch.Tensor, # (*, N)
all_frames_to_global: geometry.Rigid3Array # (N, 8)
) -> geometry.Vec3Array: # (*, N, 14)
"""Put atom literature positions (atom14 encoding) in each rigid group."""
# Pick the appropriate transform for every atom.
residx_to_group_idx = get_rc_tensor(
rc.restype_atom14_to_rigid_group,
aatype
)
group_mask = torch.nn.functional.one_hot(
residx_to_group_idx,
num_classes=8
) # shape (*, N, 14, 8)
# geometry.Rigid3Array with shape (N, 14)
map_atoms_to_global = all_frames_to_global[..., None, :] * group_mask
map_atoms_to_global = map_atoms_to_global.map_tensor_fn(
partial(torch.sum, dim=-1)
)
# Gather the literature atom positions for each residue.
# geometry.Vec3Array with shape (N, 14)
lit_positions = geometry.Vec3Array.from_array(
get_rc_tensor(
rc.restype_atom14_rigid_group_positions,
aatype
)
)
# Transform each atom from its local frame to the global frame.
# geometry.Vec3Array with shape (N, 14)
pred_positions = map_atoms_to_global.apply_to_point(lit_positions)
# Mask out non-existing atoms.
mask = get_rc_tensor(rc.restype_atom14_mask, aatype)
pred_positions = pred_positions * mask
return pred_positions
def extreme_ca_ca_distance_violations(
positions: geometry.Vec3Array, # (N, 37(14))
mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
max_angstrom_tolerance=1.5,
eps: float = 1e-6
) -> torch.Tensor:
"""Counts residues whose Ca is a large distance from its neighbor."""
this_ca_pos = positions[..., :-1, 1] # (N - 1,)
this_ca_mask = mask[..., :-1, 1] # (N - 1)
next_ca_pos = positions[..., 1:, 1] # (N - 1,)
next_ca_mask = mask[..., 1:, 1] # (N - 1)
has_no_gap_mask = (
    (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
).to(torch.float32)
ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, eps)
violations = (ca_ca_distance - rc.ca_ca) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
return tensor_utils.masked_mean(mask=mask, value=violations, dim=-1)
def get_chi_atom_indices(device: torch.device):
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[rc.atom_order[atom] for atom in chi_angle]
)
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return torch.tensor(chi_atom_indices, device=device)
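# Sanity-check sketch for the table built above (assumes this module's
# imports): one row per restype plus UNK, four chi angles, and four atom
# indices per angle, zero-padded where a chi angle is undefined.
def _chi_atom_indices_demo():
    chi_atom_indices = get_chi_atom_indices(torch.device("cpu"))
    assert chi_atom_indices.shape == (21, 4, 4)
    # ALA defines no chi angles, so its rows are all padding zeros
    assert torch.all(chi_atom_indices[rc.restype_order["A"]] == 0)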
def compute_chi_angles(
positions: geometry.Vec3Array,
mask: torch.Tensor,
aatype: torch.Tensor
):
"""Computes the chi angles given all atom positions and the amino acid type.
Args:
positions: A Vec3Array of shape
[num_res, rc.atom_type_num], with positions of
atoms needed to calculate chi angles. Supports up to 1 batch dimension.
mask: An optional tensor of shape
[num_res, rc.atom_type_num] that masks which atom
positions are set for each residue. If given, then the chi mask will be
set to 1 for a chi angle only if the amino acid has that chi angle and all
the chi atoms needed to calculate that chi angle are set. If not given
(set to None), the chi mask will be set to 1 for a chi angle if the amino
acid has that chi angle and whether the actual atoms needed to calculate
it were set will be ignored.
aatype: A tensor of shape [num_res] with amino acid type integer
code (0 to 21). Supports up to 1 batch dimension.
Returns:
A tuple of tensors (chi_angles, mask), where both have shape
[num_res, 4]. The mask masks out unused chi angles for amino acid
types that have less than 4 chi angles. If atom_positions_mask is set, the
chi mask will also mask out uncomputable chi angles.
"""
# Don't assert on the num_res and batch dimensions as they might be unknown.
assert positions.shape[-1] == rc.atom_type_num
assert mask.shape[-1] == rc.atom_type_num
no_batch_dims = len(aatype.shape) - 1
# Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
chi_atom_indices = get_chi_atom_indices(aatype.device)
# DISCREPANCY: DeepMind doesn't remove the gaps here. I don't know why
# theirs works.
aatype_gapless = torch.clamp(aatype, max=20)
# Select atoms to compute chis. Shape: [*, num_res, chis=4, atoms=4].
atom_indices = chi_atom_indices[aatype_gapless]
# Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3].
chi_angle_atoms = positions.map_tensor_fn(
partial(
tensor_utils.batched_gather,
inds=atom_indices,
dim=-1,
no_batch_dims=no_batch_dims + 1
)
)
a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)]
chi_angles = geometry.dihedral_angle(a, b, c, d)
# Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4].
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = torch.tensor(chi_angles_mask, device=aatype.device)
# Compute the chi angle mask. Shape [num_res, chis=4].
chi_mask = chi_angles_mask[aatype_gapless]
# The chi_mask is set to 1 only when all necessary chi angle atoms were set.
# Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4].
chi_angle_atoms_mask = tensor_utils.batched_gather(
mask,
atom_indices,
dim=-1,
no_batch_dims=no_batch_dims + 1
)
# Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4].
chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, dim=-1)
chi_mask = chi_mask * chi_angle_atoms_mask.to(torch.float32)
return chi_angles, chi_mask
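# Hypothetical call pattern for compute_chi_angles, shapes only; the random
# coordinates make the angle values meaningless, and Vec3Array.from_array is
# assumed to behave as in its usage elsewhere in this module.
def _compute_chi_angles_demo():
    positions = geometry.Vec3Array.from_array(torch.randn(8, 37, 3))
    mask = torch.ones(8, 37)
    aatype = torch.randint(0, 20, (8,))

    chi_angles, chi_mask = compute_chi_angles(positions, mask, aatype)
    assert chi_angles.shape == (8, 4) and chi_mask.shape == (8, 4)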
def make_transform_from_reference(
a_xyz: geometry.Vec3Array,
b_xyz: geometry.Vec3Array,
c_xyz: geometry.Vec3Array
) -> geometry.Rigid3Array:
"""Returns rotation and translation matrices to convert from reference.
Note that this method does not take care of symmetries. If you provide the
coordinates in the non-standard way, the A atom will end up in the negative
y-axis rather than in the positive y-axis. You need to take care of such
cases in your code.
Args:
a_xyz: A Vec3Array.
b_xyz: A Vec3Array.
c_xyz: A Vec3Array.
Returns:
A Rigid3Array which, when applied to coordinates in a canonicalized
reference frame, will give coordinates approximately equal
the original coordinates (in the global frame).
"""
rotation = geometry.Rot3Array.from_two_vectors(
c_xyz - b_xyz,
a_xyz - b_xyz
)
return geometry.Rigid3Array(rotation, b_xyz)
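# Plain-torch sketch of the Gram-Schmidt construction that
# geometry.Rot3Array.from_two_vectors performs in the function above: the
# first vector becomes the x-axis, the second is orthogonalized against it,
# and the cross product completes a right-handed frame.
def _rot_from_two_vectors_sketch(e0: torch.Tensor, e1: torch.Tensor, eps=1e-8):
    e0 = e0 / (torch.linalg.norm(e0, dim=-1, keepdim=True) + eps)
    e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdim=True)
    e1 = e1 / (torch.linalg.norm(e1, dim=-1, keepdim=True) + eps)
    e2 = torch.cross(e0, e1, dim=-1)
    # Columns are the images of the unit x, y, z axes
    return torch.stack([e0, e1, e2], dim=-1)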
def make_backbone_affine(
positions: geometry.Vec3Array,
mask: torch.Tensor,
aatype: torch.Tensor,
) -> Tuple[geometry.Rigid3Array, torch.Tensor]:
a = rc.atom_order['N']
b = rc.atom_order['CA']
c = rc.atom_order['C']
rigid_mask = (mask[..., a] * mask[..., b] * mask[..., c])
rigid = make_transform_from_reference(
a_xyz=positions[..., a],
b_xyz=positions[..., b],
c_xyz=positions[..., c],
)
return rigid, rigid_mask
from argparse import HelpFormatter
from operator import attrgetter
class ArgparseAlphabetizer(HelpFormatter):
"""
Sorts the optional arguments of an argparse parser alphabetically
"""
@staticmethod
def sort_actions(actions):
return sorted(actions, key=attrgetter("option_strings"))
# Formats the help message
def add_arguments(self, actions):
actions = ArgparseAlphabetizer.sort_actions(actions)
super(ArgparseAlphabetizer, self).add_arguments(actions)
# Formats the usage message
def add_usage(self, usage, actions, groups, prefix=None):
actions = ArgparseAlphabetizer.sort_actions(actions)
args = usage, actions, groups, prefix
super(ArgparseAlphabetizer, self).add_usage(*args)
def remove_arguments(parser, args):
for arg in args:
for action in parser._actions:
opts = vars(action)["option_strings"]
if(arg in opts):
parser._handle_conflict_resolve(None, [(arg, action)])
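# Example usage (hypothetical option names): the formatter alphabetizes
# options in --help output, and remove_arguments strips options a wrapper
# script does not want to expose.
def _argparse_utils_demo():
    import argparse

    parser = argparse.ArgumentParser(formatter_class=ArgparseAlphabetizer)
    parser.add_argument("--zeta", type=int, default=0)
    parser.add_argument("--alpha", type=str, default="x")

    remove_arguments(parser, ["--zeta"])

    parser.print_help()  # --alpha is listed; --zeta has been removed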