"examples/python_rs/vscode:/vscode.git/clone" did not exist on "fe83f8aa3c96e238ef97275d5fec94b216d26743"
Commit e9d2d893 authored by Kolja Stahl's avatar Kolja Stahl
Browse files

val_loss fix and stop sampling recycling iterations in validation

parent 70d6bda5
@@ -283,7 +283,8 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
        keyed_probs.append(
            ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
        )
if(self.config.supervised.uniform_recycling):
if(self.stage == "train" and self.config.supervised.uniform_recycling):
        recycling_probs = [
            1. / (max_iters + 1) for _ in range(max_iters + 1)
        ]
......
@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule):
        # Compute loss
        loss = self.loss(outputs, batch)
self.log("loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule):
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        loss = self.loss(outputs, batch)
self.log("val_loss", loss, prog_bar=True)
        return {"val_loss": loss}

    def validation_epoch_end(self, _):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment