"git@developer.sourcefind.cn:OpenDAS/sparseconvnet.git" did not exist on "808dce4d45a37506d9e75b8582afe6877c8cfb0d"
Commit 2e3e51c8 authored by Kolja Stahl's avatar Kolja Stahl
Browse files

move uniform_recycling

parent 809a9861
...@@ -174,7 +174,6 @@ config = mlc.ConfigDict( ...@@ -174,7 +174,6 @@ config = mlc.ConfigDict(
}, },
"supervised": { "supervised": {
"clamp_prob": 0.9, "clamp_prob": 0.9,
"uniform_recycling": True,
"supervised_features": [ "supervised_features": [
"all_atom_mask", "all_atom_mask",
"all_atom_positions", "all_atom_positions",
...@@ -194,6 +193,7 @@ config = mlc.ConfigDict( ...@@ -194,6 +193,7 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": False, "supervised": False,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
...@@ -206,6 +206,7 @@ config = mlc.ConfigDict( ...@@ -206,6 +206,7 @@ config = mlc.ConfigDict(
"crop_size": None, "crop_size": None,
"supervised": True, "supervised": True,
"subsample_recycling": False, "subsample_recycling": False,
"uniform_recycling": False,
}, },
"train": { "train": {
"fixed_size": True, "fixed_size": True,
...@@ -221,6 +222,7 @@ config = mlc.ConfigDict( ...@@ -221,6 +222,7 @@ config = mlc.ConfigDict(
"clamp_prob": 0.9, "clamp_prob": 0.9,
"subsample_recycling": True, "subsample_recycling": True,
"max_distillation_msa_clusters": 1000, "max_distillation_msa_clusters": 1000,
"uniform_recycling": True,
}, },
"data_module": { "data_module": {
"use_small_bfd": False, "use_small_bfd": False,
......
...@@ -284,21 +284,19 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader): ...@@ -284,21 +284,19 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
("use_clamped_fape", [1 - clamp_prob, clamp_prob]) ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
) )
if(self.stage == "train" and self.config.supervised.uniform_recycling): if(stage_cfg.uniform_recycling):
recycling_probs = [ recycling_probs = [
1. / (max_iters + 1) for _ in range(max_iters + 1) 1. / (max_iters + 1) for _ in range(max_iters + 1)
] ]
keyed_probs.append(
("no_recycling_iters", recycling_probs)
)
else: else:
recycling_probs = [ recycling_probs = [
0. for _ in range(max_iters + 1) 0. for _ in range(max_iters + 1)
] ]
recycling_probs[-1] = 1. recycling_probs[-1] = 1.
keyed_probs.append(
("no_recycling_iters", recycling_probs) keyed_probs.append(
) ("no_recycling_iters", recycling_probs)
)
keys, probs = zip(*keyed_probs) keys, probs = zip(*keyed_probs)
max_len = max([len(p) for p in probs]) max_len = max([len(p) for p in probs])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment