Commit 6713a264 authored by Christina Floristean
Browse files

Fixes for contiguous cropping

parent ab09ded4
......@@ -374,43 +374,45 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
def get_contiguous_crop_idx(protein, crop_size, generator):
    """Choose one random contiguous residue window per chain.

    Chains are visited in a random order. For each chain a window length is
    drawn, then a start position inside the chain, and the covered residue
    indices are collected.

    Args:
        protein: Mapping that provides ``"asym_id"`` (per-residue chain id
            tensor; assumed 1-D -- TODO confirm) and ``"seq_length"``
            (convertible with ``int()``).
        crop_size: Residue budget shared by all chains.
        generator: ``torch.Generator`` handed to ``torch.randperm`` and to
            the module's ``randint`` helper.

    Returns:
        1-D tensor of the selected residue indices, sorted ascending, on the
        same device as ``protein["asym_id"]``.
    """
    # Only the inverse map (chain slot of every residue) and the per-chain
    # residue counts are needed; the unique ids themselves are not used.
    _, chain_idxs, chain_lens = protein["asym_id"].unique(
        dim=-1, return_inverse=True, return_counts=True
    )
    device = chain_lens.device

    shuffle_idx = torch.randperm(
        chain_lens.shape[-1], device=device, generator=generator
    )

    # A stable sort groups residue positions by chain while keeping their
    # original order, so the element at each chain's exclusive cumulative
    # count is the position of that chain's first residue.
    _, idx_sorted = torch.sort(chain_idxs, stable=True)
    first_sorted_pos = chain_lens.cumsum(dim=0) - chain_lens
    asym_offsets = idx_sorted[first_sorted_pos]

    num_budget = crop_size
    num_remaining = int(protein["seq_length"])

    crop_idxs = []
    for idx in shuffle_idx:
        chain_len = int(chain_lens[idx])
        # After this subtraction: residues in chains not yet visited.
        num_remaining -= chain_len

        # Most residues this chain may contribute.
        crop_size_max = min(num_budget, chain_len)
        # Fewest residues it must contribute so that the chains still to
        # come are able to absorb the rest of the budget.
        crop_size_min = min(chain_len, max(0, num_budget - num_remaining))

        # NOTE(review): bounds are passed through unchanged; `randint` is
        # defined elsewhere in this module and presumably treats `upper`
        # as inclusive -- confirm.
        chain_crop_size = randint(lower=crop_size_min,
                                  upper=crop_size_max,
                                  generator=generator,
                                  device=device)

        num_budget -= chain_crop_size

        chain_start = randint(lower=0,
                              upper=chain_len - chain_crop_size,
                              generator=generator,
                              device=device)

        window_start = int(asym_offsets[idx]) + chain_start
        # Build the indices on the input's device so the concatenated result
        # can be used to index tensors that live there.
        crop_idxs.append(
            torch.arange(window_start, window_start + chain_crop_size, device=device)
        )

    # Chains were visited in shuffled order; restore ascending residue order.
    return torch.concat(crop_idxs).sort().values
@curry1
......@@ -453,7 +455,7 @@ def random_crop_to_size(
if subsample_templates:
templates_crop_start = randint(lower=0,
upper=num_templates + 1,
upper=num_templates,
generator=g,
device=protein["seq_length"].device)
templates_select_indices = torch.randperm(
......@@ -480,8 +482,7 @@ def random_crop_to_size(
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_start = templates_crop_start
v = v[slice(crop_start, crop_start + num_templates_crop_size)]
v = v[slice(templates_crop_start, templates_crop_start + num_templates_crop_size)]
elif is_num_res:
v = torch.index_select(v, i, crop_idxs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment