"...models/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "ac5ed37f92a6807d3ecc793dbf62e2bf0c960ef2"
Commit d65b76a2 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix inference-time bug in data pipeline, change config defaults

parent 13290527
......@@ -50,7 +50,7 @@ def model_config(name, train=False, low_prec=False):
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e5)
set_inf(c, 1e4)
return c
......@@ -185,7 +185,7 @@ config = mlc.ConfigDict(
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
......@@ -197,7 +197,7 @@ config = mlc.ConfigDict(
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_msa_clusters": 128,
"max_template_hits": 4,
"max_templates": 4,
"num_ensemble": 1,
......@@ -209,7 +209,7 @@ config = mlc.ConfigDict(
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_msa_clusters": 128,
"max_template_hits": 20,
"max_templates": 4,
"num_ensemble": 1,
......
......@@ -156,19 +156,19 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
num_ensemble = mode_cfg.num_ensemble
num_recycling = int(tensors["no_recycling_iters"])
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= num_recycling + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else:
tensors_0 = wrap_ensemble_fn(tensors, 0)
tensors = tree.map_structure(lambda x: x[None], tensors_0)
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
return tensors
......
Markdown is supported
Attach a file by dragging &amp; dropping, selecting, or pasting it. 0% uploaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment