"runtime/rust/python-wheel/examples/hello_world/client.py" did not exist on "ffbc06ccf7c9abb40123f3d6ea047caff4609c6c"
Commit 71189e20 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix TM calculation bugs

parent a3c2ae51
...@@ -453,8 +453,10 @@ def distogram_loss( ...@@ -453,8 +453,10 @@ def distogram_loss(
def _calculate_bin_centers(boundaries: torch.Tensor): def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0] step = boundaries[1] - boundaries[0]
bin_centers = breaks + step / 2 bin_centers = boundaries + step / 2
bin_centers = torch.cat([bin_centers, [bin_centers[-1] + step]], dim=0) bin_centers = torch.cat(
[bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
)
return bin_centers return bin_centers
...@@ -463,7 +465,6 @@ def _calculate_expected_aligned_error( ...@@ -463,7 +465,6 @@ def _calculate_expected_aligned_error(
aligned_distance_error_probs: torch.Tensor, aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks) bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return ( return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1] bin_centers[-1]
...@@ -474,6 +475,7 @@ def compute_predicted_aligned_error( ...@@ -474,6 +475,7 @@ def compute_predicted_aligned_error(
logits: torch.Tensor, logits: torch.Tensor,
max_bin: int = 31, max_bin: int = 31,
no_bins: int = 64, no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits. """Computes aligned confidence metrics from logits.
...@@ -516,9 +518,11 @@ def compute_tm( ...@@ -516,9 +518,11 @@ def compute_tm(
residue_weights: Optional[torch.Tensor] = None, residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31, max_bin: int = 31,
no_bins: int = 64, no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
if(residue_weights is None): if(residue_weights is None):
residue_weights = np.ones(logits.shape[-2]) residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace( boundaries = torch.linspace(
0, 0,
...@@ -529,6 +533,7 @@ def compute_tm( ...@@ -529,6 +533,7 @@ def compute_tm(
bin_centers = _calculate_bin_centers(boundaries) bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights) torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19) clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8 d0 = 1.24 * (clipped_n - 15) ** (1./3) - 1.8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment