Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
71189e20
"runtime/rust/python-wheel/examples/hello_world/client.py" did not exist on "ffbc06ccf7c9abb40123f3d6ea047caff4609c6c"
Commit
71189e20
authored
Oct 11, 2021
by
Gustaf Ahdritz
Browse files
Fix TM calculation bugs
parent
a3c2ae51
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
4 deletions
+9
-4
openfold/utils/loss.py
openfold/utils/loss.py
+9
-4
No files found.
openfold/utils/loss.py
View file @
71189e20
...
@@ -453,8 +453,10 @@ def distogram_loss(
...
@@ -453,8 +453,10 @@ def distogram_loss(
def
_calculate_bin_centers
(
boundaries
:
torch
.
Tensor
):
def
_calculate_bin_centers
(
boundaries
:
torch
.
Tensor
):
step
=
boundaries
[
1
]
-
boundaries
[
0
]
step
=
boundaries
[
1
]
-
boundaries
[
0
]
bin_centers
=
breaks
+
step
/
2
bin_centers
=
boundaries
+
step
/
2
bin_centers
=
torch
.
cat
([
bin_centers
,
[
bin_centers
[
-
1
]
+
step
]],
dim
=
0
)
bin_centers
=
torch
.
cat
(
[
bin_centers
,
(
bin_centers
[
-
1
]
+
step
).
unsqueeze
(
-
1
)],
dim
=
0
)
return
bin_centers
return
bin_centers
...
@@ -463,7 +465,6 @@ def _calculate_expected_aligned_error(
...
@@ -463,7 +465,6 @@ def _calculate_expected_aligned_error(
aligned_distance_error_probs
:
torch
.
Tensor
,
aligned_distance_error_probs
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
bin_centers
=
_calculate_bin_centers
(
alignment_confidence_breaks
)
bin_centers
=
_calculate_bin_centers
(
alignment_confidence_breaks
)
return
(
return
(
torch
.
sum
(
aligned_distance_error_probs
*
bin_centers
,
dim
=-
1
),
torch
.
sum
(
aligned_distance_error_probs
*
bin_centers
,
dim
=-
1
),
bin_centers
[
-
1
]
bin_centers
[
-
1
]
...
@@ -474,6 +475,7 @@ def compute_predicted_aligned_error(
...
@@ -474,6 +475,7 @@ def compute_predicted_aligned_error(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
max_bin
:
int
=
31
,
max_bin
:
int
=
31
,
no_bins
:
int
=
64
,
no_bins
:
int
=
64
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""Computes aligned confidence metrics from logits.
"""Computes aligned confidence metrics from logits.
...
@@ -516,9 +518,11 @@ def compute_tm(
...
@@ -516,9 +518,11 @@ def compute_tm(
residue_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
residue_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
max_bin
:
int
=
31
,
max_bin
:
int
=
31
,
no_bins
:
int
=
64
,
no_bins
:
int
=
64
,
eps
:
float
=
1e-8
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
(
residue_weights
is
None
):
if
(
residue_weights
is
None
):
residue_weights
=
np
.
ones
(
logits
.
shape
[
-
2
])
residue_weights
=
logits
.
new_
ones
(
logits
.
shape
[
-
2
])
boundaries
=
torch
.
linspace
(
boundaries
=
torch
.
linspace
(
0
,
0
,
...
@@ -529,6 +533,7 @@ def compute_tm(
...
@@ -529,6 +533,7 @@ def compute_tm(
bin_centers
=
_calculate_bin_centers
(
boundaries
)
bin_centers
=
_calculate_bin_centers
(
boundaries
)
torch
.
sum
(
residue_weights
)
torch
.
sum
(
residue_weights
)
n
=
logits
.
shape
[
-
2
]
clipped_n
=
max
(
n
,
19
)
clipped_n
=
max
(
n
,
19
)
d0
=
1.24
*
(
clipped_n
-
15
)
**
(
1.
/
3
)
-
1.8
d0
=
1.24
*
(
clipped_n
-
15
)
**
(
1.
/
3
)
-
1.8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment