"...llm/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "494d56255a94f3546558bfe84c35d38b3ffcfed1"
Commit 08a5be86 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

update the merge label functions so that the merging dimension is not hard coded

parent 66a60d58
...@@ -1893,23 +1893,18 @@ def merge_labels(batch, per_asym_residue_index, labels, align): ...@@ -1893,23 +1893,18 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
cur_out = {} cur_out = {}
for i, j in align: for i, j in align:
label = labels[j][k] label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based # to 1-based
cur_residue_index = per_asym_residue_index[i + 1] cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)==0 or "template" in k: if len(v.shape)==0 or "template" in k:
continue continue
else: else:
cur_out[i] = label[cur_residue_index] dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())] cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0: if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=0) new_v = torch.concat(cur_out, dim=dimension_to_merge)
merged_nres = new_v.shape[0] print(f"k is {k} shape:{label.shape} and dimension_to_merge:{dimension_to_merge}")
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v outs[k] = new_v
print(f"finished merging") print(f"finished merging")
for k,v in outs.items(): for k,v in outs.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment