Unverified Commit c775cc12 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #58 from lhatsk/pr

update loss logging and stop sampling recycling iterations in validation
parents b45f6234 8ed52b70
......@@ -174,7 +174,6 @@ config = mlc.ConfigDict(
},
"supervised": {
"clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
......@@ -194,6 +193,7 @@ config = mlc.ConfigDict(
"crop_size": None,
"supervised": False,
"subsample_recycling": False,
"uniform_recycling": False,
},
"eval": {
"fixed_size": True,
......@@ -206,6 +206,7 @@ config = mlc.ConfigDict(
"crop_size": None,
"supervised": True,
"subsample_recycling": False,
"uniform_recycling": False,
},
"train": {
"fixed_size": True,
......@@ -221,6 +222,7 @@ config = mlc.ConfigDict(
"clamp_prob": 0.9,
"subsample_recycling": True,
"max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
},
"data_module": {
"use_small_bfd": False,
......
......@@ -245,7 +245,7 @@ class OpenFoldDataset(torch.utils.data.IterableDataset):
class OpenFoldBatchCollator:
def __init__(self, config, generator, stage="train"):
def __init__(self, config, stage="train"):
self.stage = stage
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
......@@ -283,18 +283,17 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
keyed_probs.append(
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(self.config.supervised.uniform_recycling):
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
......
......@@ -66,7 +66,7 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
loss = self.loss(outputs, batch)
self.log("loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
......@@ -79,6 +79,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss = self.loss(outputs, batch)
self.log("val_loss", loss, prog_bar=True)
return {"val_loss": loss}
def validation_epoch_end(self, _):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment