Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
15105078
Commit
15105078
authored
Jun 28, 2023
by
Geoffrey Yu
Browse files
added batch_size dimesion to compute_tm
parent
a3ea7c65
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
openfold/utils/loss.py
openfold/utils/loss.py
+2
-2
No files found.
openfold/utils/loss.py
View file @
15105078
...
@@ -694,10 +694,10 @@ def compute_tm(
...
@@ -694,10 +694,10 @@ def compute_tm(
predicted_tm_term
=
torch
.
sum
(
probs
*
tm_per_bin
,
dim
=-
1
)
predicted_tm_term
=
torch
.
sum
(
probs
*
tm_per_bin
,
dim
=-
1
)
n
=
residue_weights
.
shape
[
-
1
]
n
=
residue_weights
.
shape
[
-
1
]
pair_mask
=
residue_weights
.
new_ones
((
n
,
n
),
dtype
=
torch
.
int32
)
pair_mask
=
residue_weights
.
new_ones
((
1
,
n
,
n
),
dtype
=
torch
.
int32
)
if
interface
:
if
interface
:
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:]).
to
(
dtype
=
pair_mask
.
dtype
)
pair_mask
*=
(
asym_id
[...,
None
]
!=
asym_id
[...,
None
,
:]).
to
(
dtype
=
pair_mask
.
dtype
)
predicted_tm_term
*=
pair_mask
predicted_tm_term
*=
pair_mask
pair_residue_weights
=
pair_mask
*
(
pair_residue_weights
=
pair_mask
*
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment