Commit ea7fcced authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

fixed merge_labels index error. Now working on cleaning up

parent fe01bb0c
......@@ -1830,7 +1830,6 @@ def get_least_asym_entity_or_longest_length(batch):
def greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
......@@ -1851,7 +1850,6 @@ def greedy_align(
best_rmsd = torch.inf
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
......@@ -1859,10 +1857,7 @@ def greedy_align(
if not used[j]: # possible candidate
cropped_pos = true_ca_poses[j]
cropped_pos = torch.squeeze(cropped_pos,0)
if not cropped_pos.shape==cur_pred_pos.shape:
# this means selected candidte is not the correct one. Skip
used[j] = True
else:
if cropped_pos.shape==cur_pred_pos.shape:
mask = true_ca_masks[j]
mask = torch.squeeze(mask,0)
print(f"cropped_pos shape: {cropped_pos.shape} cur_pred_pos shape: {cur_pred_pos.shape}")
......@@ -1871,9 +1866,11 @@ def greedy_align(
torch.squeeze(cropped_pos,0), torch.squeeze(cur_pred_pos,0),
(cur_pred_mask * mask).bool()
)
print(f"rmsd is {rmsd}")
if (rmsd is not None) and (rmsd < best_rmsd):
best_rmsd = rmsd
best_idx = j
print(f"best_idx is {best_idx}")
assert best_idx is not None
used[best_idx] = True
......@@ -1906,14 +1903,13 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue
else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions':
dimension_to_merge=1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out[i] = label
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=dimension_to_merge)
......@@ -2138,7 +2134,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
gc.collect()
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment