Unverified Commit c775cc12 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #58 from lhatsk/pr

update loss logging and stop sampling recycling iterations in validation
parents b45f6234 8ed52b70
...@@ -174,7 +174,6 @@ config = mlc.ConfigDict( ...@@ -174,7 +174,6 @@ config = mlc.ConfigDict(
}, },
"supervised": { "supervised": {
"clamp_prob": 0.9, "clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [ "supervised_features": [
"all_atom_mask", "all_atom_mask",
"all_atom_positions", "all_atom_positions",
...@@ -194,6 +193,7 @@ config = mlc.ConfigDict( ...@@ -194,6 +193,7 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": False, "supervised": False,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
...@@ -206,6 +206,7 @@ config = mlc.ConfigDict( ...@@ -206,6 +206,7 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": True, "supervised": True,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"train": { "train": {
"fixed_size": True, "fixed_size": True,
...@@ -221,6 +222,7 @@ config = mlc.ConfigDict( ...@@ -221,6 +222,7 @@ config = mlc.ConfigDict(
"clamp_prob": 0.9, "clamp_prob": 0.9,
"subsample_recycling": True, "subsample_recycling": True,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
}, },
"data_module": { "data_module": {
"use_small_bfd": False, "use_small_bfd": False,
......
...@@ -245,7 +245,7 @@ class OpenFoldDataset(torch.utils.data.IterableDataset): ...@@ -245,7 +245,7 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator: class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"): def __init__(self, config, stage="train"):
self.stage = stage self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config) self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
...@@ -283,18 +283,17 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -283,18 +283,17 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs.append( keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob]) ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
) )
if(self.config.supervised.uniform_recycling):
if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1) 1. / (max_iters + 1) for _ in range(max_iters + 1)
] ]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else: else:
recycling_probs = [ recycling_probs = [
0. for _ in range(max_iters + 1) 0. for _ in range(max_iters + 1)
] ]
recycling_probs[-1] = 1. recycling_probs[-1] = 1.
keyed_probs.append( keyed_probs.append(
("no_recycling_iters", recycling_probs) ("no_recycling_iters", recycling_probs)
) )
......
...@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss # Compute loss
loss = self.loss(outputs, batch) loss = self.loss(outputs, batch)
self.log("loss", loss)
return {"loss": loss} return {"loss": loss}
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
...@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs = self(batch) outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = self.loss(outputs, batch) loss = self.loss(outputs, batch)
self.log("val_loss", loss, prog_bar=True)
return {"val_loss": loss} return {"val_loss": loss}
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment