Commit 08a5be86 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update the merge label functions so that the merging dimension is not hard coded

parent 66a60d58
...@@ -1893,23 +1893,18 @@ def merge_labels(batch, per_asym_residue_index, labels, align): ...@@ -1893,23 +1893,18 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_out = {} cur_out = {}
for i, j in align: for i, j in align:
label = labels[j][k] label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)==0 or "template" in k: if len(v.shape)==0 or "template" in k:
continue continue
else: else:
cur_out[i] = label[cur_residue_index] dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=0) new_v = torch.concat(cur_out, dim=dimension_to_merge)
merged_nres = new_v.shape[0] print(f"k is {k} shape:{label.shape} and dimension_to_merge:{dimension_to_merge}")
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v outs[k] = new_v
print(f"finished merging") print(f"finished merging")
for k,v in outs.items(): for k,v in outs.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment