# confidence.py
import torch
from typing import Dict, Optional


def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
    """Computes per-residue pLDDT from logits.
    Args:
        plddt_logits: [num_res, num_bins] logits output from the PredictedLDDTHead.
    Returns:
        plddt: [num_res] per-residue pLDDT.
    """
    num_bins = plddt_logits.shape[-1]
    bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
    bin_width = 1.0 / num_bins
    bounds = torch.arange(
        start=0.5 * bin_width, end=1.0, step=bin_width, device=plddt_logits.device
    )
    plddt = torch.sum(
        bin_probs * bounds.view(*((1,) * len(bin_probs.shape[:-1])), *bounds.shape),
        dim=-1,
    )
    return plddt
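

# Illustrative usage sketch for predicted_lddt: the 50-bin shape below is an
# assumption; in practice plddt_logits would come from the model's
# PredictedLDDTHead.
#
#     plddt_logits = torch.randn(128, 50)     # hypothetical [num_res, num_bins]
#     plddt = predicted_lddt(plddt_logits)    # [num_res], each value in (0, 1)
#     plddt_percent = 100.0 * plddt           # conventional 0-100 pLDDT scale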


def compute_bin_values(breaks: torch.Tensor) -> torch.Tensor:
    """Gets the bin centers from the bin edges.
    Args:
        breaks: [num_bins - 1] the error bin edges.
    Returns:
        bin_centers: [num_bins] the error bin centers.
    """
    step = breaks[1] - breaks[0]
    bin_values = breaks + step / 2
    bin_values = torch.cat(
        [bin_values, (bin_values[-1] + step).unsqueeze(-1)], dim=0
    )
    return bin_values
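

# Worked example for compute_bin_values (hypothetical numbers): three edges
# yield four bin centers; each center is the edge plus half a step, and the
# last center is extrapolated one full step past the final edge.
#
#     compute_bin_values(torch.tensor([0.5, 1.0, 1.5]))
#     # -> tensor([0.7500, 1.2500, 1.7500, 2.2500])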


def compute_predicted_aligned_error(
    bin_edges: torch.Tensor,
    bin_probs: torch.Tensor,
) -> torch.Tensor:
    """Calculates expected aligned distance errors for every pair of residues.
    Args:
        bin_edges: [num_bins - 1] the error bin edges.
        bin_probs: [num_res, num_res, num_bins] the predicted probs for each
        error bin, for each pair of residues.
    Returns:
        predicted_aligned_error: [num_res, num_res] the expected aligned distance
        error for each pair of residues.
    """
    bin_values = compute_bin_values(bin_edges)
    return torch.sum(bin_probs * bin_values, dim=-1)


def predicted_aligned_error(
    pae_logits: torch.Tensor,
    max_bin: int = 31,
    num_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.
    Args:
        pae_logits: [num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
        max_bin: the upper bound used to build the error bin edges.
        num_bins: the number of error bins.
    Returns:
        aligned_error_probs_per_bin: [num_res, num_res, num_bins] the predicted
        aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [num_res, num_res] the expected aligned distance
        error for each pair of residues.
    """
    bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1)
    bin_edges = torch.linspace(0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

    predicted_aligned_error = compute_predicted_aligned_error(
        bin_edges=bin_edges,
        bin_probs=bin_probs,
    )

    return {
        "aligned_error_probs_per_bin": bin_probs,
        "predicted_aligned_error": predicted_aligned_error,
    }
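

# Usage sketch for predicted_aligned_error, assuming the default 64-bin PAE
# head with max_bin=31; pae_logits would come from the model's
# PredictedAlignedErrorHead.
#
#     pae_logits = torch.randn(num_res, num_res, 64)  # hypothetical logits
#     out = predicted_aligned_error(pae_logits)
#     out["predicted_aligned_error"]       # [num_res, num_res] expected error
#     out["aligned_error_probs_per_bin"]   # [num_res, num_res, 64] bin probs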


def predicted_tm_score(
    pae_logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    num_bins: int = 64,
    eps: float = 1e-8,
    asym_id: Optional[torch.Tensor] = None,
    interface: bool = False,
    **kwargs,
) -> torch.Tensor:
    """Computes predicted TM alignment or predicted interface TM alignment score.
    Args:
        logits: [num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
        breaks: [num_bins] the error bins.
        residue_weights: [num_res] the per residue weights to use for the
        expectation.
        asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
        ipTM calculation, i.e. when interface=True.
        interface: If True, interface predicted TM score is computed.
    Returns:
        ptm_score: The predicted TM alignment or the predicted iTM score.
    """
    pae_logits = pae_logits.float()
    if residue_weights is None:
        residue_weights = pae_logits.new_ones(pae_logits.shape[:-2])

    breaks = torch.linspace(0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

    def tm_kernel(nres):
        # d0 normalization from the TM-score definition; clipping nres keeps d0
        # positive for short chains.
        clipped_n = max(nres, 19)
        d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3.0) - 1.8
        return lambda x: 1.0 / (1.0 + (x / d0) ** 2)

    def rmsd_kernel(eps):  # kept for computing a predicted RMSD variant
        return lambda x: 1.0 / (x + eps)

    bin_centers = compute_bin_values(breaks)
    probs = torch.nn.functional.softmax(pae_logits, dim=-1)
    tm_per_bin = tm_kernel(nres=pae_logits.shape[-2])(bin_centers)
    # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    # rmsd_per_bin = rmsd_kernel(eps)(bin_centers)
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape)
    if interface:
        assert asym_id is not None, "must provide asym_id for iptm calculation."
        pair_mask *= asym_id[..., :, None] != asym_id[..., None, :]

    predicted_tm_term *= pair_mask

    pair_residue_weights = pair_mask * (
        residue_weights[None, :] * residue_weights[:, None]
    )
    normed_residue_mask = pair_residue_weights / (
        eps + pair_residue_weights.sum(dim=-1, keepdim=True)
    )

    # Report the per-alignment expectation at the residue whose weighted score
    # is maximal (the best alignment anchor).
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
    weighted = per_alignment * residue_weights
    ret = per_alignment.gather(
        dim=-1, index=weighted.max(dim=-1, keepdim=True).indices
    ).squeeze(dim=-1)
    return ret
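

# Minimal smoke test (illustrative sketch): random logits with the default bin
# settings above; the two-chain asym_id layout is hypothetical and is used only
# to exercise the interface (ipTM) path.
if __name__ == "__main__":
    num_res = 16
    pae_logits = torch.randn(num_res, num_res, 64)
    asym_id = torch.cat([torch.zeros(num_res // 2), torch.ones(num_res // 2)])

    pae = predicted_aligned_error(pae_logits)
    print(pae["predicted_aligned_error"].shape)  # torch.Size([16, 16])

    ptm = predicted_tm_score(pae_logits)
    iptm = predicted_tm_score(pae_logits, asym_id=asym_id, interface=True)
    print(float(ptm), float(iptm))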