Commit 71189e20 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix TM calculation bugs

parent a3c2ae51
......@@ -453,8 +453,10 @@ def distogram_loss(
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = breaks + step / 2
bin_centers = torch.cat([bin_centers, [bin_centers[-1] + step]], dim=0)
bin_centers = boundaries + step / 2
bin_centers = torch.cat(
[bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
)
return bin_centers
......@@ -463,7 +465,6 @@ def _calculate_expected_aligned_error(
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1]
......@@ -474,6 +475,7 @@ def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
......@@ -516,9 +518,11 @@ def compute_tm(
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if(residue_weights is None):
residue_weights = np.ones(logits.shape[-2])
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(
0,
......@@ -529,6 +533,7 @@ def compute_tm(
bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment