"...http/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "88ad3425c4f8affd0fdc0431713f114c9f8058c3"
Commit f649cccd authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix cropping bug

parent 4ab59433
......@@ -1134,10 +1134,10 @@ def random_crop_to_size(
n = seq_length - num_res_crop_size
if batch_mode == "clamped":
right_anchor = n + 1
right_anchor = n
elif batch_mode == "unclamped":
x = _randint(0, n)
right_anchor = n - x + 1
right_anchor = n - x
else:
raise ValueError("Invalid batch mode")
......
......@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
"""Input pipeline data transformers that can be ensembled and averaged."""
transforms = []
......@@ -117,7 +117,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=torch.Generator().seed(),
seed=ensemble_seed,
)
)
transforms.append(
......@@ -142,17 +142,23 @@ def process_tensors_from_config(
):
"""Based on the config, apply filters and transformations to the data."""
ensemble_seed = torch.Generator().seed()
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fns = ensembled_transform_fns(
common_cfg,
mode_cfg,
batch_mode,
ensemble_seed,
)
fn = compose(fns)
d["ensemble_index"] = i
return fn(d)
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = mode_cfg.num_ensemble
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
......@@ -163,6 +169,7 @@ def process_tensors_from_config(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else:
tensors_0 = wrap_ensemble_fn(tensors, 0)
tensors = tree.map_structure(lambda x: x[None], tensors_0)
return tensors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment