Commit bc49758a authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update data_transforms_multimer to the latest version of multimer branch

parent da92663d
...@@ -347,7 +347,7 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator): ...@@ -347,7 +347,7 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
return get_contiguous_crop_idx(protein, crop_size, generator) return get_contiguous_crop_idx(protein, crop_size, generator)
target_res_idx = randint(lower=0, target_res_idx = randint(lower=0,
upper=interface_residues.shape[-1], upper=interface_residues.shape[-1] - 1,
generator=generator, generator=generator,
device=positions.device) device=positions.device)
...@@ -374,43 +374,45 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator): ...@@ -374,43 +374,45 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
def get_contiguous_crop_idx(protein, crop_size, generator): def get_contiguous_crop_idx(protein, crop_size, generator):
unique_asym_ids, chain_lens = protein["asym_id"].unique(return_counts=True) unique_asym_ids, chain_idxs, chain_lens = protein["asym_id"].unique(dim=-1,
return_inverse=True,
return_counts=True)
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator) shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
num_remaining = int(chain_lens.sum())
_, idx_sorted = torch.sort(chain_idxs, stable=True)
cum_sum = chain_lens.cumsum(dim=0)
cum_sum = torch.cat((torch.tensor([0]), cum_sum[:-1]), dim=0)
asym_offsets = idx_sorted[cum_sum]
num_budget = crop_size num_budget = crop_size
num_remaining = int(protein["seq_length"])
crop_idxs = [] crop_idxs = []
for idx in shuffle_idx:
chain_len = int(chain_lens[idx])
num_remaining -= chain_len
per_asym_residue_index = {} crop_size_max = min(num_budget, chain_len)
for cur_asym_id in unique_asym_ids: crop_size_min = min(chain_len, max(0, num_budget - num_remaining))
asym_mask = (protein["asym_id"]== cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(protein["asym_id"], asym_mask)[0]
for j, idx in enumerate(shuffle_idx):
this_len = int(chain_lens[idx])
num_remaining -= this_len
# num res at most we can keep in this ent
crop_size_max = min(num_budget, this_len)
# num res at least we shall keep in this ent
crop_size_min = min(this_len, max(0, num_budget - num_remaining))
chain_crop_size = randint(lower=crop_size_min, chain_crop_size = randint(lower=crop_size_min,
upper=crop_size_max + 1, upper=crop_size_max,
generator=generator, generator=generator,
device=chain_lens.device) device=chain_lens.device)
num_budget -= chain_crop_size
chain_start = randint(lower=0, chain_start = randint(lower=0,
upper=this_len - chain_crop_size + 1, upper=chain_len - chain_crop_size,
generator=generator, generator=generator,
device=chain_lens.device) device=chain_lens.device)
cur_asym_id = unique_asym_ids[int(idx)].item()
asym_offset = per_asym_residue_index[int(cur_asym_id)] asym_offset = asym_offsets[idx]
crop_idxs.append( crop_idxs.append(
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size) torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
) )
asym_offset += this_len
num_budget -= chain_crop_size
return torch.concat(crop_idxs) return torch.concat(crop_idxs).sort().values
@curry1 @curry1
...@@ -453,7 +455,7 @@ def random_crop_to_size( ...@@ -453,7 +455,7 @@ def random_crop_to_size(
if subsample_templates: if subsample_templates:
templates_crop_start = randint(lower=0, templates_crop_start = randint(lower=0,
upper=num_templates + 1, upper=num_templates,
generator=g, generator=g,
device=protein["seq_length"].device) device=protein["seq_length"].device)
templates_select_indices = torch.randperm( templates_select_indices = torch.randperm(
...@@ -480,8 +482,7 @@ def random_crop_to_size( ...@@ -480,8 +482,7 @@ def random_crop_to_size(
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"): if i == 0 and k.startswith("template"):
crop_start = templates_crop_start v = v[slice(templates_crop_start, templates_crop_start + num_templates_crop_size)]
v = v[slice(crop_start, crop_start + num_templates_crop_size)]
elif is_num_res: elif is_num_res:
v = torch.index_select(v, i, crop_idxs) v = torch.index_select(v, i, crop_idxs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment