Commit 66a60d58 authored by Geoffrey Yu

start modifying the merge_labels function to make it compatible with dataloader inputs

parent e9794a62
@@ -1610,6 +1610,7 @@ def masked_msa_loss(logits, true_msa, bert_mask, num_classes, eps=1e-8, **kwargs
Returns:
Masked MSA loss
"""
print(f"logits shape: {logits.shape} true_msa shape:{true_msa.shape}")
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
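For context, the loss in this hunk is a plain softmax cross-entropy between the MSA logits and a one-hot encoding of true_msa. A minimal sketch of that computation with toy shapes (the repo's actual softmax_cross_entropy helper is not shown in this diff, so its exact form here is an assumption):

```python
import torch

def softmax_cross_entropy(logits, labels):
    # Cross-entropy between raw logits and one-hot targets, summed over classes.
    return -torch.sum(
        labels.to(logits.dtype) * torch.nn.functional.log_softmax(logits, dim=-1),
        dim=-1,
    )

# Toy shapes: [n_seq, n_res, num_classes] logits against an integer-coded MSA.
num_classes = 23
logits = torch.randn(2, 5, num_classes)
true_msa = torch.randint(0, num_classes, (2, 5))
errors = softmax_cross_entropy(
    logits, torch.nn.functional.one_hot(true_msa, num_classes=num_classes)
)
print(errors.shape)  # torch.Size([2, 5])
```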
@@ -1749,7 +1750,7 @@ def get_optimal_transform(
src_atoms = torch.zeros((1, 3), device=src_atoms.device).float()
tgt_atoms = src_atoms
else:
src_atoms = src_atoms[mask, :]
src_atoms = src_atoms.to('cuda:0')[mask, :]
tgt_atoms = tgt_atoms.to('cuda:0')[mask, :]
src_center = src_atoms.mean(-2, keepdim=True)
tgt_center = tgt_atoms.mean(-2, keepdim=True)
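The change in this hunk moves src_atoms and tgt_atoms onto cuda:0 before boolean-mask indexing, presumably because the ground-truth labels arrive on CPU from the dataloader while predictions live on the GPU. A toy CPU sketch of the mask-then-center step that follows (helper name and shapes are illustrative only):

```python
import torch

def mask_and_center(src_atoms, tgt_atoms, mask):
    # Keep only atoms flagged valid by the boolean mask, then subtract each
    # point cloud's centroid (mean over the atom dimension).
    src_sel = src_atoms[mask, :]
    tgt_sel = tgt_atoms[mask, :]
    src_center = src_sel.mean(-2, keepdim=True)
    tgt_center = tgt_sel.mean(-2, keepdim=True)
    return src_sel - src_center, tgt_sel - tgt_center

# Toy point clouds of 10 CA atoms, 7 of them valid.
src = torch.randn(10, 3)
tgt = torch.randn(10, 3)
mask = torch.tensor([True] * 7 + [False] * 3)
src_centered, tgt_centered = mask_and_center(src, tgt, mask)
print(src_centered.shape)  # torch.Size([7, 3])
```

Hard-coding 'cuda:0' assumes a single-GPU run; deriving the target device from the prediction tensors would generalize to multi-GPU setups.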
@@ -1857,7 +1858,6 @@ def greedy_align(
best_idx = None
cur_asym_list = entity_2_asym_list[int(cur_entity_ids)]
cur_residue_index = per_asym_residue_index[int(cur_asym_id)]
cur_pred_pos = pred_ca_pos[asym_mask]
cur_pred_mask = pred_ca_mask[asym_mask]
for next_asym_id in cur_asym_list:
@@ -1890,27 +1890,30 @@ def merge_labels(batch, per_asym_residue_index, labels, align):
num_res = batch["msa_mask"].shape[-1]
outs = {}
for k, v in labels[0].items():
if k in [
"resolution",
]:
continue
cur_out = {}
for i, j in align:
label = labels[j][k]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
cur_out[i] = label[cur_residue_index]
if len(v.shape)==0 or "template" in k:
continue
else:
cur_out[i] = label[cur_residue_index]
cur_out = [x[1] for x in sorted(cur_out.items())]
new_v = torch.concat(cur_out, dim=0)
merged_nres = new_v.shape[0]
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v
if len(cur_out)>0:
new_v = torch.concat(cur_out, dim=0)
merged_nres = new_v.shape[0]
assert (
merged_nres <= num_res
), f"bad merged num res: {merged_nres} > {num_res}. something is wrong."
if merged_nres < num_res: # must pad
pad_dim = new_v.shape[1:]
pad_v = new_v.new_zeros((num_res - merged_nres, *pad_dim))
new_v = torch.concat((new_v, pad_v), dim=0)
outs[k] = new_v
print(f"finished merging")
for k,v in outs.items():
print(f"{k}:{v.shape}")
return outs
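For each label key, the reworked merge_labels gathers one slice per chain in alignment order, concatenates them along the residue axis, and zero-pads to num_res; scalar and template entries are now skipped so dataloader-style label dicts no longer break the concatenation. A self-contained toy sketch of that merge-and-pad step (an illustration, not the repo's helper):

```python
import torch

def merge_and_pad(chunks, num_res):
    # Concatenate per-chain label slices along the residue axis and zero-pad
    # up to the total residue count (illustrative helper, not merge_labels).
    merged = torch.cat(chunks, dim=0)
    assert merged.shape[0] <= num_res, "merged residues exceed num_res"
    if merged.shape[0] < num_res:
        pad = merged.new_zeros((num_res - merged.shape[0], *merged.shape[1:]))
        merged = torch.cat((merged, pad), dim=0)
    return merged

# Toy example: two chains of 4 and 3 residues, each with [37, 3] atom
# coordinates, padded to num_res = 10.
chain_a = torch.ones(4, 37, 3)
chain_b = torch.ones(3, 37, 3)
print(merge_and_pad([chain_a, chain_b], num_res=10).shape)  # torch.Size([10, 37, 3])
```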
@@ -2050,7 +2053,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].float() for l in labels
] # list([nres,])
unique_asym_ids = torch.unique(batch["asym_id"])
per_asym_residue_index = {}
@@ -2059,7 +2061,6 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
per_asym_residue_index[int(cur_asym_id)] = batch["residue_index"][asym_mask]
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym is : {anchor_gt_asym} and anchor_pred_asym is {anchor_pred_asym}")
anchor_gt_idx = int(anchor_gt_asym) - 1
unique_entity_ids = torch.unique(batch["entity_id"])
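The surrounding code builds per_asym_residue_index by grouping batch["residue_index"] under each unique asym_id (the loop body itself is elided between these hunks). A toy sketch of that grouping, assuming flat 1-D id and index tensors:

```python
import torch

def build_per_asym_residue_index(asym_id, residue_index):
    # Group residue indices by chain (asym) id; mirrors the elided loop
    # that fills per_asym_residue_index in the hunk above.
    per_asym = {}
    for cur_asym_id in torch.unique(asym_id):
        asym_mask = asym_id == cur_asym_id
        per_asym[int(cur_asym_id)] = residue_index[asym_mask]
    return per_asym

# Toy example: two chains with 3 and 2 residues.
asym_id = torch.tensor([1, 1, 1, 2, 2])
residue_index = torch.tensor([0, 1, 2, 0, 1])
print(build_per_asym_residue_index(asym_id, residue_index))
# {1: tensor([0, 1, 2]), 2: tensor([0, 1])}
```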
@@ -2100,15 +2101,13 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
del aligned_true_ca_poses
del r,x
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
merged_labels = merge_labels(
batch,
per_asym_residue_index,
labels,
align,
)
print(f"finished multi-chain permutation and final align is {align}")
return merged_labels
@@ -2122,9 +2121,9 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
"""
features,labels = batch
features['resolution'] = labels[2]['resolution'] # firstly update the resolution feature
# first remove the recycling dimension of input features
features = tensor_tree_map(lambda t: t[..., -1], features)
features['resolution'] = labels[0]['resolution']
# then permutate ground truth chains before calculating the loss
permutated_labels = self.multi_chain_perm_align(out,features,labels)
permutated_labels.pop('aatype')
......
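In the training-step hunk above, the recycling dimension is stripped from every input feature first, and only then is resolution copied over from the ground-truth labels (now labels[0] rather than labels[2]). A minimal stand-in for that last-recycling-iteration selection over a flat feature dict (the repo's tensor_tree_map presumably also recurses into nested containers):

```python
import torch

def tree_map_last(features):
    # Flat-dict stand-in for tensor_tree_map(lambda t: t[..., -1], features):
    # keep only the final recycling iteration of every feature tensor.
    return {k: v[..., -1] for k, v in features.items()}

# Toy features with a trailing recycling dimension of size 3.
features = {
    "aatype": torch.zeros(10, 3, dtype=torch.long),
    "residue_index": torch.arange(10).unsqueeze(-1).expand(10, 3),
}
features = tree_map_last(features)
print(features["aatype"].shape)  # torch.Size([10])
```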