"vscode:/vscode.git/clone" did not exist on "04e50aba33109d9f60c609a0f100ebb413ffabad"
Commit a3c2ae51 authored by Gustaf Ahdritz

Add TM calculations, move testing wrappers

parent 9eda0b43
...@@ -7,6 +7,7 @@ from operator import add
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
MSA_FEATURE_NAMES = [
    'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'
...@@ -535,3 +536,10 @@ def make_atom14_masks(protein):
    protein['atom37_atom_exists'] = residx_atom37_mask
    return protein
def make_atom14_masks_np(batch):
    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
    out = make_atom14_masks(batch)
    out = tensor_tree_map(lambda t: np.array(t), out)
    return out
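For reference, here is a minimal usage sketch of the relocated NumPy wrapper. The batch keys, dtypes, and sizes below are illustrative assumptions, not values taken from the repository's test fixtures:

```python
import numpy as np
from openfold.features.data_transforms import make_atom14_masks_np

# Hypothetical toy batch: 8 residues, all residue type 0, fully unmasked.
batch = {
    "aatype": np.zeros(8, dtype=np.int64),     # per-residue amino acid indices
    "seq_mask": np.ones(8, dtype=np.float32),  # per-residue mask
}

out = make_atom14_masks_np(batch)
# The wrapper round-trips through torch and returns NumPy arrays such as
# "atom14_atom_exists" (shape [8, 14]) and "atom37_atom_exists" (shape [8, 37]).
print(out["atom14_atom_exists"].shape, out["atom37_atom_exists"].shape)
```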
...@@ -17,7 +17,11 @@ import torch
import torch.nn as nn
from openfold.model.primitives import Linear
-from openfold.utils.loss import compute_plddt
+from openfold.utils.loss import (
+    compute_plddt,
+    compute_tm,
+    compute_predicted_aligned_error,
+)
class AuxiliaryHeads(nn.Module):
...@@ -71,6 +75,12 @@ class AuxiliaryHeads(nn.Module):
        if(self.config.tm.enabled):
            tm_logits = self.tm(outputs["pair"])
            aux_out["tm_logits"] = tm_logits
            aux_out["predicted_tm_score"] = compute_tm(
                tm_logits, **self.config.tm
            )
            aux_out.update(compute_predicted_aligned_error(
                tm_logits, **self.config.tm,
            ))
        return aux_out
...
...@@ -75,83 +75,6 @@ def get_chi_atom_indices():
    return chi_atom_indices
def compute_residx(batch):
out = {}
float_type = batch["seq_mask"].dtype
aatype = batch["aatype"]
restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[
rc.restype_1to3[rt]
]
restype_atom14_to_atom37.append([
(rc.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in rc.atom_types
])
restype_atom14_mask.append(
[(1. if name else 0.) for name in atom_names]
)
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = aatype.new_tensor(
restype_atom14_to_atom37
)
restype_atom37_to_atom14 = aatype.new_tensor(
restype_atom37_to_atom14
)
restype_atom14_mask = batch["seq_mask"].new_tensor(
restype_atom14_mask
)
residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype]
residx_atom14_mask = restype_atom14_mask[aatype]
out["residx_atom14_to_atom37"] = residx_atom14_to_atom37
out["atom14_atom_exists"] = residx_atom14_mask
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype]
out["residx_atom37_to_atom14"] = residx_atom37_to_atom14
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=float_type)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[aatype]
out["atom37_atom_exists"] = residx_atom37_mask
return out
def compute_residx_np(batch):
batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
out = compute_residx(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
def atom14_to_atom37(atom14, batch):
    atom37_data = batched_gather(
        atom14,
...
...@@ -18,7 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
from openfold.utils import feats
...@@ -28,6 +28,7 @@ from openfold.utils.tensor_utils import (
    tensor_tree_map,
    masked_mean,
    permute_final_dims,
    batched_gather,
)
...@@ -450,6 +451,100 @@ def distogram_loss(
    return mean
def _calculate_bin_centers(boundaries: torch.Tensor):
    step = boundaries[1] - boundaries[0]
    bin_centers = boundaries + step / 2
    bin_centers = torch.cat(
        [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
    )
    return bin_centers
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,  # absorbs extra entries (e.g. "enabled") when called with **config.tm
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned
            distance error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0,
        max_bin,
        steps=(no_bins - 1),
        device=logits.device,
    )

    aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
    predicted_aligned_error, max_predicted_aligned_error = (
        _calculate_expected_aligned_error(
            alignment_confidence_breaks=boundaries,
            aligned_distance_error_probs=aligned_confidence_probs,
        )
    )

    return {
        "aligned_confidence_probs": aligned_confidence_probs,
        "predicted_aligned_error": predicted_aligned_error,
        "max_predicted_aligned_error": max_predicted_aligned_error,
    }
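For illustration only (not part of this commit), a small sketch of running the new confidence post-processing on random logits; the shapes are made up:

```python
import torch
from openfold.utils.loss import compute_predicted_aligned_error

# Hypothetical toy input: batch of 1, 10 residues, 64 aligned-error bins.
logits = torch.randn(1, 10, 10, 64)

out = compute_predicted_aligned_error(logits, max_bin=31, no_bins=64)
print(out["predicted_aligned_error"].shape)       # torch.Size([1, 10, 10])
print(float(out["max_predicted_aligned_error"]))  # 31.75 with these bin settings
```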
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,  # absorbs extra entries (e.g. "enabled") when called with **config.tm
) -> torch.Tensor:
    if(residue_weights is None):
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0,
        max_bin,
        steps=(no_bins - 1),
        device=logits.device,
    )

    bin_centers = _calculate_bin_centers(boundaries)

    n = torch.sum(residue_weights)
    clipped_n = max(n, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1. / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    tm_per_bin = 1. / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    weighted = per_alignment * residue_weights

    # Return the score of the best-scoring alignment anchor residue
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
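For context, compute_tm reduces the pairwise bin probabilities with the TM-score kernel 1 / (1 + d^2 / d0(N)^2), averages over residue pairs, and returns the value for the best-scoring alignment anchor residue. A minimal sketch with random logits (shapes are illustrative only):

```python
import torch
from openfold.utils.loss import compute_tm

# Hypothetical toy input: 10 residues, 64 bins, no batch dimension.
logits = torch.randn(10, 10, 64)

ptm = compute_tm(logits, max_bin=31, no_bins=64)
print(float(ptm))  # a scalar predicted TM-score in (0, 1)
```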
def tm_loss(
    logits,
    final_affine_tensor,
...
...@@ -19,6 +19,7 @@ import numpy as np
import unittest
import ml_collections as mlc
from openfold.features.data_transforms import make_atom14_masks
from openfold.utils.affine_utils import T, affine_vector_to_4x4
import openfold.utils.feats as feats
from openfold.utils.loss import (
...@@ -310,7 +311,7 @@ class TestLoss(unittest.TestCase):
        def _build_extra_feats_np():
            b = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
            b.update(feats.build_ambiguity_feats(b))
-            b.update(feats.compute_residx(b))
+            b.update(make_atom14_masks(b))
            return tensor_tree_map(lambda t: np.array(t), b)
        batch = _build_extra_feats_np()
...
...@@ -18,6 +18,7 @@ import torch.nn as nn
import numpy as np
import unittest
from openfold.config import model_config
from openfold.features.data_transforms import make_atom14_masks
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
...@@ -73,7 +74,7 @@ class TestModel(unittest.TestCase):
        batch["seq_mask"] = torch.randint(
            low=0, high=2, size=(batch_size, n_res)
        ).float()
-        batch.update(feats.compute_residx(batch))
+        batch.update(make_atom14_masks(batch))
        add_recycling_dims = lambda t: (
            t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
...
...@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest
from openfold.features.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
...@@ -157,7 +158,7 @@ class TestStructureModule(unittest.TestCase):
            axis=0
        )
-        batch.update(feats.compute_residx_np(batch))
+        batch.update(make_atom14_masks_np(batch))
        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
...