Commit f649cccd authored by Gustaf Ahdritz

Fix cropping bug

parent 4ab59433
@@ -1134,10 +1134,10 @@ def random_crop_to_size(
     n = seq_length - num_res_crop_size
     if batch_mode == "clamped":
-        right_anchor = n + 1
+        right_anchor = n
     elif batch_mode == "unclamped":
         x = _randint(0, n)
-        right_anchor = n - x + 1
+        right_anchor = n - x
     else:
         raise ValueError("Invalid batch mode")
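This is an off-by-one fix. Assuming, as the `_randint(0, n)` call above suggests, that `_randint` samples inclusively on both bounds, and that the crop start is later drawn as `_randint(0, right_anchor)` (an assumption about code outside this hunk), the old `+ 1` let the crop window start one residue too far right and run past the end of the sequence. A minimal sketch:

```python
import torch

def _randint(lower, upper):
    # Stand-in for the pipeline's helper, assumed inclusive on both ends.
    return int(torch.randint(lower, upper + 1, (1,))[0])

seq_length, num_res_crop_size = 100, 60
n = seq_length - num_res_crop_size  # 40

# Old code: right_anchor = n + 1, so the crop start can be 41 and the
# window [41, 41 + 60) overruns the 100-residue sequence by one.
worst_old_start = n + 1  # worst case of _randint(0, n + 1)
assert worst_old_start + num_res_crop_size > seq_length

# Fixed code: right_anchor = n keeps the worst-case window [40, 100) in bounds.
worst_new_start = n  # worst case of _randint(0, n)
assert worst_new_start + num_res_crop_size <= seq_length
```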
@@ -68,7 +68,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
     return transforms


-def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
+def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode, ensemble_seed):
     """Input pipeline data transformers that can be ensembled and averaged."""
     transforms = []
@@ -117,7 +117,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
                 crop_feats,
                 mode_cfg.subsample_templates,
                 batch_mode=batch_mode,
-                seed=torch.Generator().seed(),
+                seed=ensemble_seed,
             )
         )
     transforms.append(
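Threading the seed through as a parameter changes when randomness is drawn: `torch.Generator().seed()` pulls a fresh nondeterministic seed on every call, so the old code reseeded on each invocation of `ensembled_transform_fns`, i.e. once per ensemble iteration. With a single `ensemble_seed` drawn in `process_tensors_from_config`, every iteration of a batch now shares the same cropping/template-subsampling randomness. A minimal sketch of the difference (the `subsample` helper is illustrative, not the pipeline's API):

```python
import torch

def subsample(n, seed):
    # Illustrative stand-in for a seeded transform such as template subsampling.
    g = torch.Generator().manual_seed(seed)
    return torch.randperm(n, generator=g)

# Old behavior: a fresh seed per call, so each ensemble iteration permutes
# differently.
a = subsample(5, torch.Generator().seed())
b = subsample(5, torch.Generator().seed())  # almost certainly != a

# New behavior: one seed per batch, reused by every ensemble iteration.
ensemble_seed = torch.Generator().seed()
assert torch.equal(subsample(5, ensemble_seed), subsample(5, ensemble_seed))
```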
@@ -142,17 +142,23 @@ def process_tensors_from_config(
 ):
     """Based on the config, apply filters and transformations to the data."""

+    ensemble_seed = torch.Generator().seed()
+
     def wrap_ensemble_fn(data, i):
         """Function to be mapped over the ensemble dimension."""
         d = data.copy()
-        fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
+        fns = ensembled_transform_fns(
+            common_cfg,
+            mode_cfg,
+            batch_mode,
+            ensemble_seed,
+        )
         fn = compose(fns)
         d["ensemble_index"] = i
         return fn(d)

     tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)

-    tensors_0 = wrap_ensemble_fn(tensors, 0)
     num_ensemble = mode_cfg.num_ensemble
     if common_cfg.resample_msa_in_recycling:
         # Separate batch per ensembling & recycling step.
@@ -163,6 +169,7 @@ def process_tensors_from_config(
             lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
         )
     else:
+        tensors_0 = wrap_ensemble_fn(tensors, 0)
         tensors = tree.map_structure(lambda x: x[None], tensors_0)

     return tensors
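Moving `tensors_0` into the `else` branch also avoids running the ensembled transforms an extra, wasted time (and consuming RNG state) when `resample_msa_in_recycling` already builds the full ensemble through the mapped `wrap_ensemble_fn`. In the non-resampling branch, `tree.map_structure` just prepends an ensemble dimension of size 1. A minimal sketch, assuming `tree` is dm-tree and using a toy feature dict in place of the real processed features:

```python
import torch
import tree  # dm-tree

# Toy stand-in for the feature dict returned by wrap_ensemble_fn(tensors, 0).
tensors_0 = {"msa": torch.zeros(4, 8), "aatype": torch.zeros(8)}

# x[None] adds a leading ensemble axis of size 1 to every leaf tensor.
stacked = tree.map_structure(lambda x: x[None], tensors_0)
assert stacked["msa"].shape == (1, 4, 8)
assert stacked["aatype"].shape == (1, 8)
```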