Commit 616de3df authored by Christina Floristean
Browse files

Fix for MSA block deletion

parent a64f1a29
...@@ -311,6 +311,11 @@ config = mlc.ConfigDict( ...@@ -311,6 +311,11 @@ config = mlc.ConfigDict(
"true_msa": [NUM_MSA_SEQ, NUM_RES], "true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [], "use_clamped_fape": [],
}, },
"block_delete_msa": {
"msa_fraction_per_block": 0.3,
"randomize_num_blocks": False,
"num_blocks": 5,
},
"masked_msa": { "masked_msa": {
"profile_prob": 0.1, "profile_prob": 0.1,
"same_prob": 0.1, "same_prob": 0.1,
...@@ -355,6 +360,7 @@ config = mlc.ConfigDict( ...@@ -355,6 +360,7 @@ config = mlc.ConfigDict(
"predict": { "predict": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512, "max_msa_clusters": 512,
"max_extra_msa": 1024, "max_extra_msa": 1024,
...@@ -368,6 +374,7 @@ config = mlc.ConfigDict( ...@@ -368,6 +374,7 @@ config = mlc.ConfigDict(
"eval": { "eval": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": False, # We want top templates. "subsample_templates": False, # We want top templates.
"block_delete_msa": False,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024, "max_extra_msa": 1024,
...@@ -381,6 +388,7 @@ config = mlc.ConfigDict( ...@@ -381,6 +388,7 @@ config = mlc.ConfigDict(
"train": { "train": {
"fixed_size": True, "fixed_size": True,
"subsample_templates": True, "subsample_templates": True,
"block_delete_msa": True,
"masked_msa_replace_fraction": 0.15, "masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 128, "max_msa_clusters": 128,
"max_extra_msa": 1024, "max_extra_msa": 1024,
......
...@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): ...@@ -71,6 +71,9 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged.""" """Input pipeline data transformers that can be ensembled and averaged."""
transforms = [] transforms = []
if mode_cfg.block_delete_msa:
transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa))
if "max_distillation_msa_clusters" in mode_cfg: if "max_distillation_msa_clusters" in mode_cfg:
transforms.append( transforms.append(
data_transforms.sample_msa_distillation( data_transforms.sample_msa_distillation(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment