Commit 2e3e51c8 authored by Kolja Stahl's avatar Kolja Stahl
Browse files

move uniform_recycling

parent 809a9861
......@@ -174,7 +174,6 @@ config = mlc.ConfigDict(
},
"supervised": {
"clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
......@@ -194,6 +193,7 @@ config = mlc.ConfigDict(
"crop_size": None,
"supervised": False,
"subsample_recycling": False,
"uniform_recycling": False,
},
"eval": {
"fixed_size": True,
......@@ -206,6 +206,7 @@ config = mlc.ConfigDict(
"crop_size": None,
"supervised": True,
"subsample_recycling": False,
"uniform_recycling": False,
},
"train": {
"fixed_size": True,
......@@ -221,6 +222,7 @@ config = mlc.ConfigDict(
"clamp_prob": 0.9,
"subsample_recycling": True,
"max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
},
"data_module": {
"use_small_bfd": False,
......
......@@ -284,18 +284,16 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
("use_clamped_fape", [1 - clamp_prob, clamp_prob])
)
if(self.stage == "train" and self.config.supervised.uniform_recycling):
if(stage_cfg.uniform_recycling):
recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1)
]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else:
recycling_probs = [
0. for _ in range(max_iters + 1)
]
recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment