Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
e9d2d893
Commit
e9d2d893
authored
Jan 13, 2022
by
Kolja Stahl
Browse files
val_loss fix and stop sampling recycling iterations in validation
parent
70d6bda5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
8 deletions
+10
-8
openfold/data/data_modules.py
openfold/data/data_modules.py
+8
-7
train_openfold.py
train_openfold.py
+2
-1
No files found.
openfold/data/data_modules.py
View file @
e9d2d893
...
@@ -283,13 +283,14 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
...
@@ -283,13 +283,14 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs
.
append
(
keyed_probs
.
append
(
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
(
"use_clamped_fape"
,
[
1
-
clamp_prob
,
clamp_prob
])
)
)
if
(
self
.
config
.
supervised
.
uniform_recycling
):
recycling_probs
=
[
if
(
self
.
stage
==
"train"
and
self
.
config
.
supervised
.
uniform_recycling
):
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
recycling_probs
=
[
]
1.
/
(
max_iters
+
1
)
for
_
in
range
(
max_iters
+
1
)
keyed_probs
.
append
(
]
(
"no_recycling_iters"
,
recycling_probs
)
keyed_probs
.
append
(
)
(
"no_recycling_iters"
,
recycling_probs
)
)
else
:
else
:
recycling_probs
=
[
recycling_probs
=
[
0.
for
_
in
range
(
max_iters
+
1
)
0.
for
_
in
range
(
max_iters
+
1
)
...
...
train_openfold.py
View file @
e9d2d893
...
@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"loss"
,
loss
)
return
{
"loss"
:
loss
}
return
{
"loss"
:
loss
}
def
validation_step
(
self
,
batch
,
batch_idx
):
def
validation_step
(
self
,
batch
,
batch_idx
):
...
@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs
=
self
(
batch
)
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"val_loss"
,
loss
,
prog_bar
=
True
)
return
{
"val_loss"
:
loss
}
return
{
"val_loss"
:
loss
}
def
validation_epoch_end
(
self
,
_
):
def
validation_epoch_end
(
self
,
_
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment