Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2c4d4183
Commit
2c4d4183
authored
Jul 14, 2023
by
Geoffrey Yu
Browse files
remove unecessary print statements
parent
e4d7f6d2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
11 deletions
+4
-11
openfold/utils/loss.py
openfold/utils/loss.py
+4
-11
No files found.
openfold/utils/loss.py
View file @
2c4d4183
...
...
@@ -1610,7 +1610,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns:
Masked MSA loss
"""
print
(
f
"logits shape:
{
logits
.
shape
}
true_msa shape:
{
true_msa
.
shape
}
"
)
errors
=
softmax_cross_entropy
(
logits
,
torch
.
nn
.
functional
.
one_hot
(
true_msa
,
num_classes
=
num_classes
)
)
...
...
@@ -1881,13 +1880,12 @@ def greedy_align(
return
align
def
merge_labels
(
batch
,
per_asym_residue_index
,
labels
,
align
):
def
merge_labels
(
per_asym_residue_index
,
labels
,
align
):
"""
batch:
labels: list of
label dicts, each with shape [nk, *]
align: list of
int, such as [2, None, 0, 1]
, each entry specify the corresponding label of the asym.
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
labels: list of
original ground truth feats
align: list of
tuples
, each entry specify the corresponding label of the asym.
"""
num_res
=
batch
[
"msa_mask"
].
shape
[
-
1
]
outs
=
{}
for
k
,
v
in
labels
[
0
].
items
():
cur_out
=
{}
...
...
@@ -1904,11 +1902,7 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_out
=
[
x
[
1
]
for
x
in
sorted
(
cur_out
.
items
())]
if
len
(
cur_out
)
>
0
:
new_v
=
torch
.
concat
(
cur_out
,
dim
=
dimension_to_merge
)
print
(
f
"k is
{
k
}
shape:
{
label
.
shape
}
and dimension_to_merge:
{
dimension_to_merge
}
"
)
outs
[
k
]
=
new_v
print
(
f
"finished merging"
)
for
k
,
v
in
outs
.
items
():
print
(
f
"
{
k
}
:
{
v
.
shape
}
"
)
return
outs
...
...
@@ -2098,7 +2092,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
gc
.
collect
()
print
(
f
"finished multi-chain permutation and final align is
{
align
}
"
)
merged_labels
=
merge_labels
(
batch
,
per_asym_residue_index
,
labels
,
align
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment