"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "25cd01e092157642573125f0076ff0afae2a9dea"
Commit a3c2ae51 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add TM calculations, move testing wrappers

parent 9eda0b43
......@@ -7,6 +7,7 @@ from operator import add
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'
......@@ -535,3 +536,10 @@ def make_atom14_masks(protein):
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
def make_atom14_masks_np(batch):
    """Numpy-in/numpy-out wrapper around make_atom14_masks.

    Every np.ndarray leaf of `batch` is converted to a torch.Tensor,
    make_atom14_masks is applied, and the resulting tensors are converted
    back to np.ndarrays.
    """
    tensor_batch = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
    masked = make_atom14_masks(tensor_batch)
    return tensor_tree_map(lambda t: np.array(t), masked)
......@@ -17,7 +17,11 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
from openfold.utils.loss import compute_plddt
from openfold.utils.loss import (
compute_plddt,
compute_tm,
compute_predicted_aligned_error,
)
class AuxiliaryHeads(nn.Module):
......@@ -70,7 +74,13 @@ class AuxiliaryHeads(nn.Module):
if(self.config.tm.enabled):
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(compute_predicted_aligned_error(
tm_logits, **self.config.tm,
))
return aux_out
......
......@@ -75,83 +75,6 @@ def get_chi_atom_indices():
return chi_atom_indices
def compute_residx(batch):
    """Build atom14 <-> atom37 index mappings and atom-existence masks.

    Reads `batch["aatype"]` (per-residue residue-type indices) and
    `batch["seq_mask"]` (used only for its dtype/device) and returns a dict
    with:
        residx_atom14_to_atom37: per-residue atom14 -> atom37 indices
        atom14_atom_exists:      per-residue atom14 existence mask
        residx_atom37_to_atom14: per-residue atom37 -> atom14 indices
        atom37_atom_exists:      per-residue atom37 existence mask
    """
    aatype = batch["aatype"]
    mask_dtype = batch["seq_mask"].dtype

    idx14_to_37 = []  # (restype, atom14) -> atom37 index
    idx37_to_14 = []  # (restype, atom37) -> atom14 index
    mask14 = []       # (restype, atom14) -> 1. where the slot is a real atom

    for one_letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[
            rc.restype_1to3[one_letter]
        ]
        # Empty-name slots are padding; they map to index 0 with mask 0.
        idx14_to_37.append(
            [rc.atom_order[name] if name else 0 for name in names14]
        )
        name_to_idx14 = {name: i for i, name in enumerate(names14)}
        idx37_to_14.append(
            [name_to_idx14.get(name, 0) for name in rc.atom_types]
        )
        mask14.append([1. if name else 0. for name in names14])

    # Dummy entries for the unknown residue type ('UNK').
    idx14_to_37.append([0] * 14)
    idx37_to_14.append([0] * 37)
    mask14.append([0.] * 14)

    idx14_to_37 = aatype.new_tensor(idx14_to_37)
    idx37_to_14 = aatype.new_tensor(idx37_to_14)
    mask14 = batch["seq_mask"].new_tensor(mask14)

    out = {}
    # Gather the per-residue-type tables by each residue's type.
    out["residx_atom14_to_atom37"] = idx14_to_37[aatype]
    out["atom14_atom_exists"] = mask14[aatype]
    out["residx_atom37_to_atom14"] = idx37_to_14[aatype]

    # atom37 existence mask, built from the canonical residue atom lists.
    mask37 = torch.zeros([21, 37], dtype=mask_dtype)
    for res_idx, one_letter in enumerate(rc.restypes):
        res_name = rc.restype_1to3[one_letter]
        for atom_name in rc.residue_atoms[res_name]:
            mask37[res_idx, rc.atom_order[atom_name]] = 1
    out["atom37_atom_exists"] = mask37[aatype]

    return out
def compute_residx_np(batch):
    """Numpy front-end for compute_residx.

    Converts np.ndarray leaves of `batch` to torch tensors, delegates to
    compute_residx, and converts the outputs back to np.ndarrays.
    """
    tensor_batch = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
    return tensor_tree_map(lambda t: np.array(t), compute_residx(tensor_batch))
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
......
......@@ -18,7 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, Optional
from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
......@@ -28,6 +28,7 @@ from openfold.utils.tensor_utils import (
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
)
......@@ -450,6 +451,100 @@ def distogram_loss(
return mean
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = breaks + step / 2
bin_centers = torch.cat([bin_centers, [bin_centers[-1] + step]], dim=0)
return bin_centers
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expected aligned-distance error under the per-bin probabilities.

    Args:
        alignment_confidence_breaks: [no_bins - 1] error-bin edges.
        aligned_distance_error_probs: [*, no_bins] per-bin probabilities.
    Returns:
        Tuple of (expected error per element, largest representable error,
        i.e. the last bin center).
    """
    centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = (aligned_distance_error_probs * centers).sum(dim=-1)
    return expected_error, centers[-1]
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
) -> Dict[str, torch.Tensor]:
    """Compute aligned-error confidence metrics from head logits.

    Args:
        logits: [*, num_res, num_res, no_bins] raw logits emitted by the
            predicted-aligned-error head.
        max_bin: value of the largest bin edge.
        no_bins: number of error bins.
    Returns:
        Dict with:
            aligned_confidence_probs: [*, num_res, num_res, no_bins]
                softmaxed per-bin probabilities.
            predicted_aligned_error: [*, num_res, num_res] expected aligned
                distance error per residue pair.
            max_predicted_aligned_error: [*] largest predictable error.
    """
    edges = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    expected_error, max_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=edges,
        aligned_distance_error_probs=probs,
    )
    return {
        "aligned_confidence_probs": probs,
        "predicted_aligned_error": expected_error,
        "max_predicted_aligned_error": max_error,
    }
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Compute the predicted TM-score from aligned-error logits.

    Args:
        logits: [*, num_res, num_res, no_bins] predicted-aligned-error
            logits.
        residue_weights: [num_res] optional per-residue weights; defaults
            to all ones.
        max_bin: value of the largest bin edge.
        no_bins: number of error bins.
        eps: small constant guarding the weight normalization.
        **kwargs: ignored; allows callers to splat whole config sections
            (e.g. compute_tm(logits, **config.tm), which includes keys such
            as "enabled").
    Returns:
        Scalar predicted TM-score (pTM) of the best-weighted alignment.
    """
    if(residue_weights is None):
        # Bug fix: was np.ones(...), which injected a numpy array into the
        # torch arithmetic below; build the default on logits' device/dtype.
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0,
        max_bin,
        steps=(no_bins - 1),
        device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)

    # Bug fix: the sum was computed but never assigned, leaving `n`
    # undefined on the next line.
    n = torch.sum(residue_weights)
    # TM-score normalization constant d0 (Zhang & Skolnick), with the
    # sequence length clipped at 19 so d0 stays positive.
    clipped_n = max(n, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1. / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    # Bug fix: `eps` was undefined; it is now a keyword parameter.
    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Return the score of the best alignment under the residue weights.
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
def tm_loss(
logits,
final_affine_tensor,
......
......@@ -19,6 +19,7 @@ import numpy as np
import unittest
import ml_collections as mlc
from openfold.features.data_transforms import make_atom14_masks
from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats
from openfold.utils.loss import (
......@@ -310,7 +311,7 @@ class TestLoss(unittest.TestCase):
def _build_extra_feats_np():
b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
b.update(feats.build_ambiguity_feats(b))
b.update(feats.compute_residx(b))
b.update(make_atom14_masks(b))
return tensor_tree_map(lambda t: np.array(t), b)
batch = _build_extra_feats_np()
......
......@@ -18,6 +18,7 @@ import torch.nn as nn
import numpy as np
import unittest
from openfold.config import model_config
from openfold.features.data_transforms import make_atom14_masks
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
......@@ -73,7 +74,7 @@ class TestModel(unittest.TestCase):
batch["seq_mask"] = torch.randint(
low=0, high=2, size=(batch_size, n_res)
).float()
batch.update(feats.compute_residx(batch))
batch.update(make_atom14_masks(batch))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
......
......@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest
from openfold.features.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
......@@ -157,7 +158,7 @@ class TestStructureModule(unittest.TestCase):
axis=0
)
batch.update(feats.compute_residx_np(batch))
batch.update(make_atom14_masks_np(batch))
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment